• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright 2016-2021 The Brenwill Workshop Ltd.
3  * SPDX-License-Identifier: Apache-2.0 OR MIT
4  *
5  * Licensed under the Apache License, Version 2.0 (the "License");
6  * you may not use this file except in compliance with the License.
7  * You may obtain a copy of the License at
8  *
9  *     http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 /*
19  * At your option, you may choose to accept this material under either:
20  *  1. The Apache License, Version 2.0, found at <http://www.apache.org/licenses/LICENSE-2.0>, or
21  *  2. The MIT License, found at <http://opensource.org/licenses/MIT>.
22  */
23 
24 #include "spirv_msl.hpp"
25 #include "GLSL.std.450.h"
26 
27 #include <algorithm>
28 #include <assert.h>
29 #include <numeric>
30 
31 using namespace spv;
32 using namespace SPIRV_CROSS_NAMESPACE;
33 using namespace std;
34 
// Sentinel returned by get_automatic_builtin_input_location() when no automatic
// location has been recorded for a builtin.
static const uint32_t k_unknown_location = ~0u;
// Sentinel component value; also used in add_msl_resource_binding() as the
// placeholder binding of a lookup key before the real index is filled in.
static const uint32_t k_unknown_component = ~0u;
// Qualifier text kept as a string constant. The pointer itself is const so the
// shared text cannot be re-pointed by accident.
static const char *const force_inline = "static inline __attribute__((always_inline))";
38 
CompilerMSL(std::vector<uint32_t> spirv_)39 CompilerMSL::CompilerMSL(std::vector<uint32_t> spirv_)
40     : CompilerGLSL(move(spirv_))
41 {
42 }
43 
CompilerMSL(const uint32_t * ir_,size_t word_count)44 CompilerMSL::CompilerMSL(const uint32_t *ir_, size_t word_count)
45     : CompilerGLSL(ir_, word_count)
46 {
47 }
48 
CompilerMSL(const ParsedIR & ir_)49 CompilerMSL::CompilerMSL(const ParsedIR &ir_)
50     : CompilerGLSL(ir_)
51 {
52 }
53 
CompilerMSL(ParsedIR && ir_)54 CompilerMSL::CompilerMSL(ParsedIR &&ir_)
55     : CompilerGLSL(std::move(ir_))
56 {
57 }
58 
add_msl_shader_input(const MSLShaderInput & si)59 void CompilerMSL::add_msl_shader_input(const MSLShaderInput &si)
60 {
61 	inputs_by_location[si.location] = si;
62 	if (si.builtin != BuiltInMax && !inputs_by_builtin.count(si.builtin))
63 		inputs_by_builtin[si.builtin] = si;
64 }
65 
add_msl_resource_binding(const MSLResourceBinding & binding)66 void CompilerMSL::add_msl_resource_binding(const MSLResourceBinding &binding)
67 {
68 	StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
69 	resource_bindings[tuple] = { binding, false };
70 
71 	// If we might need to pad argument buffer members to positionally align
72 	// arg buffer indexes, also maintain a lookup by argument buffer index.
73 	if (msl_options.pad_argument_buffer_resources)
74 	{
75 		StageSetBinding arg_idx_tuple = { binding.stage, binding.desc_set, k_unknown_component };
76 
77 #define ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(rez) \
78 	arg_idx_tuple.binding = binding.msl_##rez; \
79 	resource_arg_buff_idx_to_binding_number[arg_idx_tuple] = binding.binding
80 
81 		switch (binding.basetype)
82 		{
83 		case SPIRType::Void:
84 		case SPIRType::Boolean:
85 		case SPIRType::SByte:
86 		case SPIRType::UByte:
87 		case SPIRType::Short:
88 		case SPIRType::UShort:
89 		case SPIRType::Int:
90 		case SPIRType::UInt:
91 		case SPIRType::Int64:
92 		case SPIRType::UInt64:
93 		case SPIRType::AtomicCounter:
94 		case SPIRType::Half:
95 		case SPIRType::Float:
96 		case SPIRType::Double:
97 			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(buffer);
98 			break;
99 		case SPIRType::Image:
100 			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(texture);
101 			break;
102 		case SPIRType::Sampler:
103 			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(sampler);
104 			break;
105 		case SPIRType::SampledImage:
106 			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(texture);
107 			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(sampler);
108 			break;
109 		default:
110 			SPIRV_CROSS_THROW("Unexpected argument buffer resource base type. When padding argument buffer elements, "
111 			                  "all descriptor set resources must be supplied with a base type by the app.");
112 		}
113 #undef ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP
114 	}
115 }
116 
add_dynamic_buffer(uint32_t desc_set,uint32_t binding,uint32_t index)117 void CompilerMSL::add_dynamic_buffer(uint32_t desc_set, uint32_t binding, uint32_t index)
118 {
119 	SetBindingPair pair = { desc_set, binding };
120 	buffers_requiring_dynamic_offset[pair] = { index, 0 };
121 }
122 
add_inline_uniform_block(uint32_t desc_set,uint32_t binding)123 void CompilerMSL::add_inline_uniform_block(uint32_t desc_set, uint32_t binding)
124 {
125 	SetBindingPair pair = { desc_set, binding };
126 	inline_uniform_blocks.insert(pair);
127 }
128 
add_discrete_descriptor_set(uint32_t desc_set)129 void CompilerMSL::add_discrete_descriptor_set(uint32_t desc_set)
130 {
131 	if (desc_set < kMaxArgumentBuffers)
132 		argument_buffer_discrete_mask |= 1u << desc_set;
133 }
134 
set_argument_buffer_device_address_space(uint32_t desc_set,bool device_storage)135 void CompilerMSL::set_argument_buffer_device_address_space(uint32_t desc_set, bool device_storage)
136 {
137 	if (desc_set < kMaxArgumentBuffers)
138 	{
139 		if (device_storage)
140 			argument_buffer_device_storage_mask |= 1u << desc_set;
141 		else
142 			argument_buffer_device_storage_mask &= ~(1u << desc_set);
143 	}
144 }
145 
is_msl_shader_input_used(uint32_t location)146 bool CompilerMSL::is_msl_shader_input_used(uint32_t location)
147 {
148 	// Don't report internal location allocations to app.
149 	return location_inputs_in_use.count(location) != 0 &&
150 	       location_inputs_in_use_fallback.count(location) == 0;
151 }
152 
get_automatic_builtin_input_location(spv::BuiltIn builtin) const153 uint32_t CompilerMSL::get_automatic_builtin_input_location(spv::BuiltIn builtin) const
154 {
155 	auto itr = builtin_to_automatic_input_location.find(builtin);
156 	if (itr == builtin_to_automatic_input_location.end())
157 		return k_unknown_location;
158 	else
159 		return itr->second;
160 }
161 
is_msl_resource_binding_used(ExecutionModel model,uint32_t desc_set,uint32_t binding) const162 bool CompilerMSL::is_msl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
163 {
164 	StageSetBinding tuple = { model, desc_set, binding };
165 	auto itr = resource_bindings.find(tuple);
166 	return itr != end(resource_bindings) && itr->second.second;
167 }
168 
169 // Returns the size of the array of resources used by the variable with the specified id.
170 // The returned value is retrieved from the resource binding added using add_msl_resource_binding().
get_resource_array_size(uint32_t id) const171 uint32_t CompilerMSL::get_resource_array_size(uint32_t id) const
172 {
173 	StageSetBinding tuple = { get_entry_point().model, get_decoration(id, DecorationDescriptorSet),
174 		                      get_decoration(id, DecorationBinding) };
175 	auto itr = resource_bindings.find(tuple);
176 	return itr != end(resource_bindings) ? itr->second.first.count : 0;
177 }
178 
get_automatic_msl_resource_binding(uint32_t id) const179 uint32_t CompilerMSL::get_automatic_msl_resource_binding(uint32_t id) const
180 {
181 	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexPrimary);
182 }
183 
get_automatic_msl_resource_binding_secondary(uint32_t id) const184 uint32_t CompilerMSL::get_automatic_msl_resource_binding_secondary(uint32_t id) const
185 {
186 	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexSecondary);
187 }
188 
get_automatic_msl_resource_binding_tertiary(uint32_t id) const189 uint32_t CompilerMSL::get_automatic_msl_resource_binding_tertiary(uint32_t id) const
190 {
191 	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexTertiary);
192 }
193 
get_automatic_msl_resource_binding_quaternary(uint32_t id) const194 uint32_t CompilerMSL::get_automatic_msl_resource_binding_quaternary(uint32_t id) const
195 {
196 	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexQuaternary);
197 }
198 
set_fragment_output_components(uint32_t location,uint32_t components)199 void CompilerMSL::set_fragment_output_components(uint32_t location, uint32_t components)
200 {
201 	fragment_output_components[location] = components;
202 }
203 
builtin_translates_to_nonarray(spv::BuiltIn builtin) const204 bool CompilerMSL::builtin_translates_to_nonarray(spv::BuiltIn builtin) const
205 {
206 	return (builtin == BuiltInSampleMask);
207 }
208 
build_implicit_builtins()209 void CompilerMSL::build_implicit_builtins()
210 {
211 	bool need_sample_pos = active_input_builtins.get(BuiltInSamplePosition);
212 	bool need_vertex_params = capture_output_to_buffer && get_execution_model() == ExecutionModelVertex &&
213 	                          !msl_options.vertex_for_tessellation;
214 	bool need_tesc_params = get_execution_model() == ExecutionModelTessellationControl;
215 	bool need_subgroup_mask =
216 	    active_input_builtins.get(BuiltInSubgroupEqMask) || active_input_builtins.get(BuiltInSubgroupGeMask) ||
217 	    active_input_builtins.get(BuiltInSubgroupGtMask) || active_input_builtins.get(BuiltInSubgroupLeMask) ||
218 	    active_input_builtins.get(BuiltInSubgroupLtMask);
219 	bool need_subgroup_ge_mask = !msl_options.is_ios() && (active_input_builtins.get(BuiltInSubgroupGeMask) ||
220 	                                                       active_input_builtins.get(BuiltInSubgroupGtMask));
221 	bool need_multiview = get_execution_model() == ExecutionModelVertex && !msl_options.view_index_from_device_index &&
222 	                      msl_options.multiview_layered_rendering &&
223 	                      (msl_options.multiview || active_input_builtins.get(BuiltInViewIndex));
224 	bool need_dispatch_base =
225 	    msl_options.dispatch_base && get_execution_model() == ExecutionModelGLCompute &&
226 	    (active_input_builtins.get(BuiltInWorkgroupId) || active_input_builtins.get(BuiltInGlobalInvocationId));
227 	bool need_grid_params = get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation;
228 	bool need_vertex_base_params =
229 	    need_grid_params &&
230 	    (active_input_builtins.get(BuiltInVertexId) || active_input_builtins.get(BuiltInVertexIndex) ||
231 	     active_input_builtins.get(BuiltInBaseVertex) || active_input_builtins.get(BuiltInInstanceId) ||
232 	     active_input_builtins.get(BuiltInInstanceIndex) || active_input_builtins.get(BuiltInBaseInstance));
233 	bool need_local_invocation_index = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInSubgroupId);
234 	bool need_workgroup_size = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInNumSubgroups);
235 
236 	if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
237 	    need_multiview || need_dispatch_base || need_vertex_base_params || need_grid_params || needs_sample_id ||
238 	    needs_subgroup_invocation_id || needs_subgroup_size || has_additional_fixed_sample_mask() || need_local_invocation_index ||
239 	    need_workgroup_size)
240 	{
241 		bool has_frag_coord = false;
242 		bool has_sample_id = false;
243 		bool has_vertex_idx = false;
244 		bool has_base_vertex = false;
245 		bool has_instance_idx = false;
246 		bool has_base_instance = false;
247 		bool has_invocation_id = false;
248 		bool has_primitive_id = false;
249 		bool has_subgroup_invocation_id = false;
250 		bool has_subgroup_size = false;
251 		bool has_view_idx = false;
252 		bool has_layer = false;
253 		bool has_local_invocation_index = false;
254 		bool has_workgroup_size = false;
255 		uint32_t workgroup_id_type = 0;
256 
257 		ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
258 			if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
259 				return;
260 			if (!interface_variable_exists_in_entry_point(var.self))
261 				return;
262 			if (!has_decoration(var.self, DecorationBuiltIn))
263 				return;
264 
265 			BuiltIn builtin = ir.meta[var.self].decoration.builtin_type;
266 
267 			if (var.storage == StorageClassOutput)
268 			{
269 				if (has_additional_fixed_sample_mask() && builtin == BuiltInSampleMask)
270 				{
271 					builtin_sample_mask_id = var.self;
272 					mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var.self);
273 					does_shader_write_sample_mask = true;
274 				}
275 			}
276 
277 			if (var.storage != StorageClassInput)
278 				return;
279 
280 			// Use Metal's native frame-buffer fetch API for subpass inputs.
281 			if (need_subpass_input && (!msl_options.use_framebuffer_fetch_subpasses))
282 			{
283 				switch (builtin)
284 				{
285 				case BuiltInFragCoord:
286 					mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var.self);
287 					builtin_frag_coord_id = var.self;
288 					has_frag_coord = true;
289 					break;
290 				case BuiltInLayer:
291 					if (!msl_options.arrayed_subpass_input || msl_options.multiview)
292 						break;
293 					mark_implicit_builtin(StorageClassInput, BuiltInLayer, var.self);
294 					builtin_layer_id = var.self;
295 					has_layer = true;
296 					break;
297 				case BuiltInViewIndex:
298 					if (!msl_options.multiview)
299 						break;
300 					mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self);
301 					builtin_view_idx_id = var.self;
302 					has_view_idx = true;
303 					break;
304 				default:
305 					break;
306 				}
307 			}
308 
309 			if ((need_sample_pos || needs_sample_id) && builtin == BuiltInSampleId)
310 			{
311 				builtin_sample_id_id = var.self;
312 				mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var.self);
313 				has_sample_id = true;
314 			}
315 
316 			if (need_vertex_params)
317 			{
318 				switch (builtin)
319 				{
320 				case BuiltInVertexIndex:
321 					builtin_vertex_idx_id = var.self;
322 					mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var.self);
323 					has_vertex_idx = true;
324 					break;
325 				case BuiltInBaseVertex:
326 					builtin_base_vertex_id = var.self;
327 					mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var.self);
328 					has_base_vertex = true;
329 					break;
330 				case BuiltInInstanceIndex:
331 					builtin_instance_idx_id = var.self;
332 					mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self);
333 					has_instance_idx = true;
334 					break;
335 				case BuiltInBaseInstance:
336 					builtin_base_instance_id = var.self;
337 					mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self);
338 					has_base_instance = true;
339 					break;
340 				default:
341 					break;
342 				}
343 			}
344 
345 			if (need_tesc_params)
346 			{
347 				switch (builtin)
348 				{
349 				case BuiltInInvocationId:
350 					builtin_invocation_id_id = var.self;
351 					mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var.self);
352 					has_invocation_id = true;
353 					break;
354 				case BuiltInPrimitiveId:
355 					builtin_primitive_id_id = var.self;
356 					mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var.self);
357 					has_primitive_id = true;
358 					break;
359 				default:
360 					break;
361 				}
362 			}
363 
364 			if ((need_subgroup_mask || needs_subgroup_invocation_id) && builtin == BuiltInSubgroupLocalInvocationId)
365 			{
366 				builtin_subgroup_invocation_id_id = var.self;
367 				mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var.self);
368 				has_subgroup_invocation_id = true;
369 			}
370 
371 			if ((need_subgroup_ge_mask || needs_subgroup_size) && builtin == BuiltInSubgroupSize)
372 			{
373 				builtin_subgroup_size_id = var.self;
374 				mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var.self);
375 				has_subgroup_size = true;
376 			}
377 
378 			if (need_multiview)
379 			{
380 				switch (builtin)
381 				{
382 				case BuiltInInstanceIndex:
383 					// The view index here is derived from the instance index.
384 					builtin_instance_idx_id = var.self;
385 					mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self);
386 					has_instance_idx = true;
387 					break;
388 				case BuiltInBaseInstance:
389 					// If a non-zero base instance is used, we need to adjust for it when calculating the view index.
390 					builtin_base_instance_id = var.self;
391 					mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self);
392 					has_base_instance = true;
393 					break;
394 				case BuiltInViewIndex:
395 					builtin_view_idx_id = var.self;
396 					mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self);
397 					has_view_idx = true;
398 					break;
399 				default:
400 					break;
401 				}
402 			}
403 
404 			if (need_local_invocation_index && builtin == BuiltInLocalInvocationIndex)
405 			{
406 				builtin_local_invocation_index_id = var.self;
407 				mark_implicit_builtin(StorageClassInput, BuiltInLocalInvocationIndex, var.self);
408 				has_local_invocation_index = true;
409 			}
410 
411 			if (need_workgroup_size && builtin == BuiltInLocalInvocationId)
412 			{
413 				builtin_workgroup_size_id = var.self;
414 				mark_implicit_builtin(StorageClassInput, BuiltInWorkgroupSize, var.self);
415 				has_workgroup_size = true;
416 			}
417 
418 			// The base workgroup needs to have the same type and vector size
419 			// as the workgroup or invocation ID, so keep track of the type that
420 			// was used.
421 			if (need_dispatch_base && workgroup_id_type == 0 &&
422 			    (builtin == BuiltInWorkgroupId || builtin == BuiltInGlobalInvocationId))
423 				workgroup_id_type = var.basetype;
424 		});
425 
426 		// Use Metal's native frame-buffer fetch API for subpass inputs.
427 		if ((!has_frag_coord || (msl_options.multiview && !has_view_idx) ||
428 		     (msl_options.arrayed_subpass_input && !msl_options.multiview && !has_layer)) &&
429 		    (!msl_options.use_framebuffer_fetch_subpasses) && need_subpass_input)
430 		{
431 			if (!has_frag_coord)
432 			{
433 				uint32_t offset = ir.increase_bound_by(3);
434 				uint32_t type_id = offset;
435 				uint32_t type_ptr_id = offset + 1;
436 				uint32_t var_id = offset + 2;
437 
438 				// Create gl_FragCoord.
439 				SPIRType vec4_type;
440 				vec4_type.basetype = SPIRType::Float;
441 				vec4_type.width = 32;
442 				vec4_type.vecsize = 4;
443 				set<SPIRType>(type_id, vec4_type);
444 
445 				SPIRType vec4_type_ptr;
446 				vec4_type_ptr = vec4_type;
447 				vec4_type_ptr.pointer = true;
448 				vec4_type_ptr.pointer_depth++;
449 				vec4_type_ptr.parent_type = type_id;
450 				vec4_type_ptr.storage = StorageClassInput;
451 				auto &ptr_type = set<SPIRType>(type_ptr_id, vec4_type_ptr);
452 				ptr_type.self = type_id;
453 
454 				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
455 				set_decoration(var_id, DecorationBuiltIn, BuiltInFragCoord);
456 				builtin_frag_coord_id = var_id;
457 				mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var_id);
458 			}
459 
460 			if (!has_layer && msl_options.arrayed_subpass_input && !msl_options.multiview)
461 			{
462 				uint32_t offset = ir.increase_bound_by(2);
463 				uint32_t type_ptr_id = offset;
464 				uint32_t var_id = offset + 1;
465 
466 				// Create gl_Layer.
467 				SPIRType uint_type_ptr;
468 				uint_type_ptr = get_uint_type();
469 				uint_type_ptr.pointer = true;
470 				uint_type_ptr.pointer_depth++;
471 				uint_type_ptr.parent_type = get_uint_type_id();
472 				uint_type_ptr.storage = StorageClassInput;
473 				auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
474 				ptr_type.self = get_uint_type_id();
475 
476 				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
477 				set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
478 				builtin_layer_id = var_id;
479 				mark_implicit_builtin(StorageClassInput, BuiltInLayer, var_id);
480 			}
481 
482 			if (!has_view_idx && msl_options.multiview)
483 			{
484 				uint32_t offset = ir.increase_bound_by(2);
485 				uint32_t type_ptr_id = offset;
486 				uint32_t var_id = offset + 1;
487 
488 				// Create gl_ViewIndex.
489 				SPIRType uint_type_ptr;
490 				uint_type_ptr = get_uint_type();
491 				uint_type_ptr.pointer = true;
492 				uint_type_ptr.pointer_depth++;
493 				uint_type_ptr.parent_type = get_uint_type_id();
494 				uint_type_ptr.storage = StorageClassInput;
495 				auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
496 				ptr_type.self = get_uint_type_id();
497 
498 				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
499 				set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
500 				builtin_view_idx_id = var_id;
501 				mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
502 			}
503 		}
504 
505 		if (!has_sample_id && (need_sample_pos || needs_sample_id))
506 		{
507 			uint32_t offset = ir.increase_bound_by(2);
508 			uint32_t type_ptr_id = offset;
509 			uint32_t var_id = offset + 1;
510 
511 			// Create gl_SampleID.
512 			SPIRType uint_type_ptr;
513 			uint_type_ptr = get_uint_type();
514 			uint_type_ptr.pointer = true;
515 			uint_type_ptr.pointer_depth++;
516 			uint_type_ptr.parent_type = get_uint_type_id();
517 			uint_type_ptr.storage = StorageClassInput;
518 			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
519 			ptr_type.self = get_uint_type_id();
520 
521 			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
522 			set_decoration(var_id, DecorationBuiltIn, BuiltInSampleId);
523 			builtin_sample_id_id = var_id;
524 			mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var_id);
525 		}
526 
527 		if ((need_vertex_params && (!has_vertex_idx || !has_base_vertex || !has_instance_idx || !has_base_instance)) ||
528 		    (need_multiview && (!has_instance_idx || !has_base_instance || !has_view_idx)))
529 		{
530 			uint32_t type_ptr_id = ir.increase_bound_by(1);
531 
532 			SPIRType uint_type_ptr;
533 			uint_type_ptr = get_uint_type();
534 			uint_type_ptr.pointer = true;
535 			uint_type_ptr.pointer_depth++;
536 			uint_type_ptr.parent_type = get_uint_type_id();
537 			uint_type_ptr.storage = StorageClassInput;
538 			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
539 			ptr_type.self = get_uint_type_id();
540 
541 			if (need_vertex_params && !has_vertex_idx)
542 			{
543 				uint32_t var_id = ir.increase_bound_by(1);
544 
545 				// Create gl_VertexIndex.
546 				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
547 				set_decoration(var_id, DecorationBuiltIn, BuiltInVertexIndex);
548 				builtin_vertex_idx_id = var_id;
549 				mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var_id);
550 			}
551 
552 			if (need_vertex_params && !has_base_vertex)
553 			{
554 				uint32_t var_id = ir.increase_bound_by(1);
555 
556 				// Create gl_BaseVertex.
557 				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
558 				set_decoration(var_id, DecorationBuiltIn, BuiltInBaseVertex);
559 				builtin_base_vertex_id = var_id;
560 				mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var_id);
561 			}
562 
563 			if (!has_instance_idx) // Needed by both multiview and tessellation
564 			{
565 				uint32_t var_id = ir.increase_bound_by(1);
566 
567 				// Create gl_InstanceIndex.
568 				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
569 				set_decoration(var_id, DecorationBuiltIn, BuiltInInstanceIndex);
570 				builtin_instance_idx_id = var_id;
571 				mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var_id);
572 			}
573 
574 			if (!has_base_instance) // Needed by both multiview and tessellation
575 			{
576 				uint32_t var_id = ir.increase_bound_by(1);
577 
578 				// Create gl_BaseInstance.
579 				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
580 				set_decoration(var_id, DecorationBuiltIn, BuiltInBaseInstance);
581 				builtin_base_instance_id = var_id;
582 				mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var_id);
583 			}
584 
585 			if (need_multiview)
586 			{
587 				// Multiview shaders are not allowed to write to gl_Layer, ostensibly because
588 				// it is implicitly written from gl_ViewIndex, but we have to do that explicitly.
589 				// Note that we can't just abuse gl_ViewIndex for this purpose: it's an input, but
590 				// gl_Layer is an output in vertex-pipeline shaders.
591 				uint32_t type_ptr_out_id = ir.increase_bound_by(2);
592 				SPIRType uint_type_ptr_out;
593 				uint_type_ptr_out = get_uint_type();
594 				uint_type_ptr_out.pointer = true;
595 				uint_type_ptr_out.pointer_depth++;
596 				uint_type_ptr_out.parent_type = get_uint_type_id();
597 				uint_type_ptr_out.storage = StorageClassOutput;
598 				auto &ptr_out_type = set<SPIRType>(type_ptr_out_id, uint_type_ptr_out);
599 				ptr_out_type.self = get_uint_type_id();
600 				uint32_t var_id = type_ptr_out_id + 1;
601 				set<SPIRVariable>(var_id, type_ptr_out_id, StorageClassOutput);
602 				set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
603 				builtin_layer_id = var_id;
604 				mark_implicit_builtin(StorageClassOutput, BuiltInLayer, var_id);
605 			}
606 
607 			if (need_multiview && !has_view_idx)
608 			{
609 				uint32_t var_id = ir.increase_bound_by(1);
610 
611 				// Create gl_ViewIndex.
612 				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
613 				set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
614 				builtin_view_idx_id = var_id;
615 				mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
616 			}
617 		}
618 
619 		if ((need_tesc_params && (msl_options.multi_patch_workgroup || !has_invocation_id || !has_primitive_id)) ||
620 		    need_grid_params)
621 		{
622 			uint32_t type_ptr_id = ir.increase_bound_by(1);
623 
624 			SPIRType uint_type_ptr;
625 			uint_type_ptr = get_uint_type();
626 			uint_type_ptr.pointer = true;
627 			uint_type_ptr.pointer_depth++;
628 			uint_type_ptr.parent_type = get_uint_type_id();
629 			uint_type_ptr.storage = StorageClassInput;
630 			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
631 			ptr_type.self = get_uint_type_id();
632 
633 			if (msl_options.multi_patch_workgroup || need_grid_params)
634 			{
635 				uint32_t var_id = ir.increase_bound_by(1);
636 
637 				// Create gl_GlobalInvocationID.
638 				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
639 				set_decoration(var_id, DecorationBuiltIn, BuiltInGlobalInvocationId);
640 				builtin_invocation_id_id = var_id;
641 				mark_implicit_builtin(StorageClassInput, BuiltInGlobalInvocationId, var_id);
642 			}
643 			else if (need_tesc_params && !has_invocation_id)
644 			{
645 				uint32_t var_id = ir.increase_bound_by(1);
646 
647 				// Create gl_InvocationID.
648 				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
649 				set_decoration(var_id, DecorationBuiltIn, BuiltInInvocationId);
650 				builtin_invocation_id_id = var_id;
651 				mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var_id);
652 			}
653 
654 			if (need_tesc_params && !has_primitive_id)
655 			{
656 				uint32_t var_id = ir.increase_bound_by(1);
657 
658 				// Create gl_PrimitiveID.
659 				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
660 				set_decoration(var_id, DecorationBuiltIn, BuiltInPrimitiveId);
661 				builtin_primitive_id_id = var_id;
662 				mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var_id);
663 			}
664 
665 			if (need_grid_params)
666 			{
667 				uint32_t var_id = ir.increase_bound_by(1);
668 
669 				set<SPIRVariable>(var_id, build_extended_vector_type(get_uint_type_id(), 3), StorageClassInput);
670 				set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize);
671 				get_entry_point().interface_variables.push_back(var_id);
672 				set_name(var_id, "spvStageInputSize");
673 				builtin_stage_input_size_id = var_id;
674 			}
675 		}
676 
677 		if (!has_subgroup_invocation_id && (need_subgroup_mask || needs_subgroup_invocation_id))
678 		{
679 			uint32_t offset = ir.increase_bound_by(2);
680 			uint32_t type_ptr_id = offset;
681 			uint32_t var_id = offset + 1;
682 
683 			// Create gl_SubgroupInvocationID.
684 			SPIRType uint_type_ptr;
685 			uint_type_ptr = get_uint_type();
686 			uint_type_ptr.pointer = true;
687 			uint_type_ptr.pointer_depth++;
688 			uint_type_ptr.parent_type = get_uint_type_id();
689 			uint_type_ptr.storage = StorageClassInput;
690 			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
691 			ptr_type.self = get_uint_type_id();
692 
693 			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
694 			set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupLocalInvocationId);
695 			builtin_subgroup_invocation_id_id = var_id;
696 			mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var_id);
697 		}
698 
699 		if (!has_subgroup_size && (need_subgroup_ge_mask || needs_subgroup_size))
700 		{
701 			uint32_t offset = ir.increase_bound_by(2);
702 			uint32_t type_ptr_id = offset;
703 			uint32_t var_id = offset + 1;
704 
705 			// Create gl_SubgroupSize.
706 			SPIRType uint_type_ptr;
707 			uint_type_ptr = get_uint_type();
708 			uint_type_ptr.pointer = true;
709 			uint_type_ptr.pointer_depth++;
710 			uint_type_ptr.parent_type = get_uint_type_id();
711 			uint_type_ptr.storage = StorageClassInput;
712 			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
713 			ptr_type.self = get_uint_type_id();
714 
715 			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
716 			set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupSize);
717 			builtin_subgroup_size_id = var_id;
718 			mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var_id);
719 		}
720 
721 		if (need_dispatch_base || need_vertex_base_params)
722 		{
723 			if (workgroup_id_type == 0)
724 				workgroup_id_type = build_extended_vector_type(get_uint_type_id(), 3);
725 			uint32_t var_id;
726 			if (msl_options.supports_msl_version(1, 2))
727 			{
728 				// If we have MSL 1.2, we can (ab)use the [[grid_origin]] builtin
729 				// to convey this information and save a buffer slot.
730 				uint32_t offset = ir.increase_bound_by(1);
731 				var_id = offset;
732 
733 				set<SPIRVariable>(var_id, workgroup_id_type, StorageClassInput);
734 				set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase);
735 				get_entry_point().interface_variables.push_back(var_id);
736 			}
737 			else
738 			{
739 				// Otherwise, we need to fall back to a good ol' fashioned buffer.
740 				uint32_t offset = ir.increase_bound_by(2);
741 				var_id = offset;
742 				uint32_t type_id = offset + 1;
743 
744 				SPIRType var_type = get<SPIRType>(workgroup_id_type);
745 				var_type.storage = StorageClassUniform;
746 				set<SPIRType>(type_id, var_type);
747 
748 				set<SPIRVariable>(var_id, type_id, StorageClassUniform);
749 				// This should never match anything.
750 				set_decoration(var_id, DecorationDescriptorSet, ~(5u));
751 				set_decoration(var_id, DecorationBinding, msl_options.indirect_params_buffer_index);
752 				set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
753 				                        msl_options.indirect_params_buffer_index);
754 			}
755 			set_name(var_id, "spvDispatchBase");
756 			builtin_dispatch_base_id = var_id;
757 		}
758 
759 		if (has_additional_fixed_sample_mask() && !does_shader_write_sample_mask)
760 		{
761 			uint32_t offset = ir.increase_bound_by(2);
762 			uint32_t var_id = offset + 1;
763 
764 			// Create gl_SampleMask.
765 			SPIRType uint_type_ptr_out;
766 			uint_type_ptr_out = get_uint_type();
767 			uint_type_ptr_out.pointer = true;
768 			uint_type_ptr_out.pointer_depth++;
769 			uint_type_ptr_out.parent_type = get_uint_type_id();
770 			uint_type_ptr_out.storage = StorageClassOutput;
771 
772 			auto &ptr_out_type = set<SPIRType>(offset, uint_type_ptr_out);
773 			ptr_out_type.self = get_uint_type_id();
774 			set<SPIRVariable>(var_id, offset, StorageClassOutput);
775 			set_decoration(var_id, DecorationBuiltIn, BuiltInSampleMask);
776 			builtin_sample_mask_id = var_id;
777 			mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var_id);
778 		}
779 
780 		if (need_local_invocation_index && !has_local_invocation_index)
781 		{
782 			uint32_t offset = ir.increase_bound_by(2);
783 			uint32_t type_ptr_id = offset;
784 			uint32_t var_id = offset + 1;
785 
786 			// Create gl_LocalInvocationIndex.
787 			SPIRType uint_type_ptr;
788 			uint_type_ptr = get_uint_type();
789 			uint_type_ptr.pointer = true;
790 			uint_type_ptr.pointer_depth++;
791 			uint_type_ptr.parent_type = get_uint_type_id();
792 			uint_type_ptr.storage = StorageClassInput;
793 
794 			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
795 			ptr_type.self = get_uint_type_id();
796 			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
797 			set_decoration(var_id, DecorationBuiltIn, BuiltInLocalInvocationIndex);
798 			builtin_local_invocation_index_id = var_id;
799 			mark_implicit_builtin(StorageClassInput, BuiltInLocalInvocationIndex, var_id);
800 		}
801 
802 		if (need_workgroup_size && !has_workgroup_size)
803 		{
804 			uint32_t offset = ir.increase_bound_by(2);
805 			uint32_t type_ptr_id = offset;
806 			uint32_t var_id = offset + 1;
807 
808 			// Create gl_WorkgroupSize.
809 			uint32_t type_id = build_extended_vector_type(get_uint_type_id(), 3);
810 			SPIRType uint_type_ptr = get<SPIRType>(type_id);
811 			uint_type_ptr.pointer = true;
812 			uint_type_ptr.pointer_depth++;
813 			uint_type_ptr.parent_type = type_id;
814 			uint_type_ptr.storage = StorageClassInput;
815 
816 			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
817 			ptr_type.self = type_id;
818 			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
819 			set_decoration(var_id, DecorationBuiltIn, BuiltInWorkgroupSize);
820 			builtin_workgroup_size_id = var_id;
821 			mark_implicit_builtin(StorageClassInput, BuiltInWorkgroupSize, var_id);
822 		}
823 	}
824 
825 	if (needs_swizzle_buffer_def)
826 	{
827 		uint32_t var_id = build_constant_uint_array_pointer();
828 		set_name(var_id, "spvSwizzleConstants");
829 		// This should never match anything.
830 		set_decoration(var_id, DecorationDescriptorSet, kSwizzleBufferBinding);
831 		set_decoration(var_id, DecorationBinding, msl_options.swizzle_buffer_index);
832 		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.swizzle_buffer_index);
833 		swizzle_buffer_id = var_id;
834 	}
835 
836 	if (!buffers_requiring_array_length.empty())
837 	{
838 		uint32_t var_id = build_constant_uint_array_pointer();
839 		set_name(var_id, "spvBufferSizeConstants");
840 		// This should never match anything.
841 		set_decoration(var_id, DecorationDescriptorSet, kBufferSizeBufferBinding);
842 		set_decoration(var_id, DecorationBinding, msl_options.buffer_size_buffer_index);
843 		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.buffer_size_buffer_index);
844 		buffer_size_buffer_id = var_id;
845 	}
846 
847 	if (needs_view_mask_buffer())
848 	{
849 		uint32_t var_id = build_constant_uint_array_pointer();
850 		set_name(var_id, "spvViewMask");
851 		// This should never match anything.
852 		set_decoration(var_id, DecorationDescriptorSet, ~(4u));
853 		set_decoration(var_id, DecorationBinding, msl_options.view_mask_buffer_index);
854 		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.view_mask_buffer_index);
855 		view_mask_buffer_id = var_id;
856 	}
857 
858 	if (!buffers_requiring_dynamic_offset.empty())
859 	{
860 		uint32_t var_id = build_constant_uint_array_pointer();
861 		set_name(var_id, "spvDynamicOffsets");
862 		// This should never match anything.
863 		set_decoration(var_id, DecorationDescriptorSet, ~(5u));
864 		set_decoration(var_id, DecorationBinding, msl_options.dynamic_offsets_buffer_index);
865 		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
866 		                        msl_options.dynamic_offsets_buffer_index);
867 		dynamic_offsets_buffer_id = var_id;
868 	}
869 
870 	// If we're returning a struct from a vertex-like entry point, we must return a position attribute.
871 	bool need_position =
872 			(get_execution_model() == ExecutionModelVertex ||
873 			 get_execution_model() == ExecutionModelTessellationEvaluation) &&
874 			!capture_output_to_buffer && !get_is_rasterization_disabled() &&
875 			!active_output_builtins.get(BuiltInPosition);
876 
877 	if (need_position)
878 	{
879 		// If we can get away with returning void from entry point, we don't need to care.
880 		// If there is at least one other stage output, we need to return [[position]],
881 		// so we need to create one if it doesn't appear in the SPIR-V. Before adding the
882 		// implicit variable, check if it actually exists already, but just has not been used
883 		// or initialized, and if so, mark it as active, and do not create the implicit variable.
884 		bool has_output = false;
885 		ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
886 			if (var.storage == StorageClassOutput && interface_variable_exists_in_entry_point(var.self))
887 			{
888 				has_output = true;
889 
890 				// Check if the var is the Position builtin
891 				if (has_decoration(var.self, DecorationBuiltIn) && get_decoration(var.self, DecorationBuiltIn) == BuiltInPosition)
892 					active_output_builtins.set(BuiltInPosition);
893 
894 				// If the var is a struct, check if any members is the Position builtin
895 				auto &var_type = get_variable_element_type(var);
896 				if (var_type.basetype == SPIRType::Struct)
897 				{
898 					auto mbr_cnt = var_type.member_types.size();
899 					for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
900 					{
901 						auto builtin = BuiltInMax;
902 						bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
903 						if (is_builtin && builtin == BuiltInPosition)
904 							active_output_builtins.set(BuiltInPosition);
905 					}
906 				}
907 			}
908 		});
909 		need_position = has_output && !active_output_builtins.get(BuiltInPosition);
910 	}
911 
912 	if (need_position)
913 	{
914 		uint32_t offset = ir.increase_bound_by(3);
915 		uint32_t type_id = offset;
916 		uint32_t type_ptr_id = offset + 1;
917 		uint32_t var_id = offset + 2;
918 
919 		// Create gl_Position.
920 		SPIRType vec4_type;
921 		vec4_type.basetype = SPIRType::Float;
922 		vec4_type.width = 32;
923 		vec4_type.vecsize = 4;
924 		set<SPIRType>(type_id, vec4_type);
925 
926 		SPIRType vec4_type_ptr;
927 		vec4_type_ptr = vec4_type;
928 		vec4_type_ptr.pointer = true;
929 		vec4_type_ptr.pointer_depth++;
930 		vec4_type_ptr.parent_type = type_id;
931 		vec4_type_ptr.storage = StorageClassOutput;
932 		auto &ptr_type = set<SPIRType>(type_ptr_id, vec4_type_ptr);
933 		ptr_type.self = type_id;
934 
935 		set<SPIRVariable>(var_id, type_ptr_id, StorageClassOutput);
936 		set_decoration(var_id, DecorationBuiltIn, BuiltInPosition);
937 		mark_implicit_builtin(StorageClassOutput, BuiltInPosition, var_id);
938 	}
939 }
940 
941 // Checks if the specified builtin variable (e.g. gl_InstanceIndex) is marked as active.
942 // If not, it marks it as active and forces a recompilation.
943 // This might be used when the optimization of inactive builtins was too optimistic (e.g. when "spvOut" is emitted).
ensure_builtin(spv::StorageClass storage,spv::BuiltIn builtin)944 void CompilerMSL::ensure_builtin(spv::StorageClass storage, spv::BuiltIn builtin)
945 {
946 	Bitset *active_builtins = nullptr;
947 	switch (storage)
948 	{
949 	case StorageClassInput:
950 		active_builtins = &active_input_builtins;
951 		break;
952 
953 	case StorageClassOutput:
954 		active_builtins = &active_output_builtins;
955 		break;
956 
957 	default:
958 		break;
959 	}
960 
961 	// At this point, the specified builtin variable must have already been declared in the entry point.
962 	// If not, mark as active and force recompile.
963 	if (active_builtins != nullptr && !active_builtins->get(builtin))
964 	{
965 		active_builtins->set(builtin);
966 		force_recompile();
967 	}
968 }
969 
mark_implicit_builtin(StorageClass storage,BuiltIn builtin,uint32_t id)970 void CompilerMSL::mark_implicit_builtin(StorageClass storage, BuiltIn builtin, uint32_t id)
971 {
972 	Bitset *active_builtins = nullptr;
973 	switch (storage)
974 	{
975 	case StorageClassInput:
976 		active_builtins = &active_input_builtins;
977 		break;
978 
979 	case StorageClassOutput:
980 		active_builtins = &active_output_builtins;
981 		break;
982 
983 	default:
984 		break;
985 	}
986 
987 	assert(active_builtins != nullptr);
988 	active_builtins->set(builtin);
989 
990 	auto &var = get_entry_point().interface_variables;
991 	if (find(begin(var), end(var), VariableID(id)) == end(var))
992 		var.push_back(id);
993 }
994 
build_constant_uint_array_pointer()995 uint32_t CompilerMSL::build_constant_uint_array_pointer()
996 {
997 	uint32_t offset = ir.increase_bound_by(3);
998 	uint32_t type_ptr_id = offset;
999 	uint32_t type_ptr_ptr_id = offset + 1;
1000 	uint32_t var_id = offset + 2;
1001 
1002 	// Create a buffer to hold extra data, including the swizzle constants.
1003 	SPIRType uint_type_pointer = get_uint_type();
1004 	uint_type_pointer.pointer = true;
1005 	uint_type_pointer.pointer_depth++;
1006 	uint_type_pointer.parent_type = get_uint_type_id();
1007 	uint_type_pointer.storage = StorageClassUniform;
1008 	set<SPIRType>(type_ptr_id, uint_type_pointer);
1009 	set_decoration(type_ptr_id, DecorationArrayStride, 4);
1010 
1011 	SPIRType uint_type_pointer2 = uint_type_pointer;
1012 	uint_type_pointer2.pointer_depth++;
1013 	uint_type_pointer2.parent_type = type_ptr_id;
1014 	set<SPIRType>(type_ptr_ptr_id, uint_type_pointer2);
1015 
1016 	set<SPIRVariable>(var_id, type_ptr_ptr_id, StorageClassUniformConstant);
1017 	return var_id;
1018 }
1019 
create_sampler_address(const char * prefix,MSLSamplerAddress addr)1020 static string create_sampler_address(const char *prefix, MSLSamplerAddress addr)
1021 {
1022 	switch (addr)
1023 	{
1024 	case MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE:
1025 		return join(prefix, "address::clamp_to_edge");
1026 	case MSL_SAMPLER_ADDRESS_CLAMP_TO_ZERO:
1027 		return join(prefix, "address::clamp_to_zero");
1028 	case MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER:
1029 		return join(prefix, "address::clamp_to_border");
1030 	case MSL_SAMPLER_ADDRESS_REPEAT:
1031 		return join(prefix, "address::repeat");
1032 	case MSL_SAMPLER_ADDRESS_MIRRORED_REPEAT:
1033 		return join(prefix, "address::mirrored_repeat");
1034 	default:
1035 		SPIRV_CROSS_THROW("Invalid sampler addressing mode.");
1036 	}
1037 }
1038 
get_stage_in_struct_type()1039 SPIRType &CompilerMSL::get_stage_in_struct_type()
1040 {
1041 	auto &si_var = get<SPIRVariable>(stage_in_var_id);
1042 	return get_variable_data_type(si_var);
1043 }
1044 
get_stage_out_struct_type()1045 SPIRType &CompilerMSL::get_stage_out_struct_type()
1046 {
1047 	auto &so_var = get<SPIRVariable>(stage_out_var_id);
1048 	return get_variable_data_type(so_var);
1049 }
1050 
get_patch_stage_in_struct_type()1051 SPIRType &CompilerMSL::get_patch_stage_in_struct_type()
1052 {
1053 	auto &si_var = get<SPIRVariable>(patch_stage_in_var_id);
1054 	return get_variable_data_type(si_var);
1055 }
1056 
get_patch_stage_out_struct_type()1057 SPIRType &CompilerMSL::get_patch_stage_out_struct_type()
1058 {
1059 	auto &so_var = get<SPIRVariable>(patch_stage_out_var_id);
1060 	return get_variable_data_type(so_var);
1061 }
1062 
get_tess_factor_struct_name()1063 std::string CompilerMSL::get_tess_factor_struct_name()
1064 {
1065 	if (get_entry_point().flags.get(ExecutionModeTriangles))
1066 		return "MTLTriangleTessellationFactorsHalf";
1067 	return "MTLQuadTessellationFactorsHalf";
1068 }
1069 
get_uint_type()1070 SPIRType &CompilerMSL::get_uint_type()
1071 {
1072 	return get<SPIRType>(get_uint_type_id());
1073 }
1074 
get_uint_type_id()1075 uint32_t CompilerMSL::get_uint_type_id()
1076 {
1077 	if (uint_type_id != 0)
1078 		return uint_type_id;
1079 
1080 	uint_type_id = ir.increase_bound_by(1);
1081 
1082 	SPIRType type;
1083 	type.basetype = SPIRType::UInt;
1084 	type.width = 32;
1085 	set<SPIRType>(uint_type_id, type);
1086 	return uint_type_id;
1087 }
1088 
emit_entry_point_declarations()1089 void CompilerMSL::emit_entry_point_declarations()
1090 {
1091 	// FIXME: Get test coverage here ...
1092 	// Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries
1093 	declare_complex_constant_arrays();
1094 
1095 	// Emit constexpr samplers here.
1096 	for (auto &samp : constexpr_samplers_by_id)
1097 	{
1098 		auto &var = get<SPIRVariable>(samp.first);
1099 		auto &type = get<SPIRType>(var.basetype);
1100 		if (type.basetype == SPIRType::Sampler)
1101 			add_resource_name(samp.first);
1102 
1103 		SmallVector<string> args;
1104 		auto &s = samp.second;
1105 
1106 		if (s.coord != MSL_SAMPLER_COORD_NORMALIZED)
1107 			args.push_back("coord::pixel");
1108 
1109 		if (s.min_filter == s.mag_filter)
1110 		{
1111 			if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
1112 				args.push_back("filter::linear");
1113 		}
1114 		else
1115 		{
1116 			if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
1117 				args.push_back("min_filter::linear");
1118 			if (s.mag_filter != MSL_SAMPLER_FILTER_NEAREST)
1119 				args.push_back("mag_filter::linear");
1120 		}
1121 
1122 		switch (s.mip_filter)
1123 		{
1124 		case MSL_SAMPLER_MIP_FILTER_NONE:
1125 			// Default
1126 			break;
1127 		case MSL_SAMPLER_MIP_FILTER_NEAREST:
1128 			args.push_back("mip_filter::nearest");
1129 			break;
1130 		case MSL_SAMPLER_MIP_FILTER_LINEAR:
1131 			args.push_back("mip_filter::linear");
1132 			break;
1133 		default:
1134 			SPIRV_CROSS_THROW("Invalid mip filter.");
1135 		}
1136 
1137 		if (s.s_address == s.t_address && s.s_address == s.r_address)
1138 		{
1139 			if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
1140 				args.push_back(create_sampler_address("", s.s_address));
1141 		}
1142 		else
1143 		{
1144 			if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
1145 				args.push_back(create_sampler_address("s_", s.s_address));
1146 			if (s.t_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
1147 				args.push_back(create_sampler_address("t_", s.t_address));
1148 			if (s.r_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
1149 				args.push_back(create_sampler_address("r_", s.r_address));
1150 		}
1151 
1152 		if (s.compare_enable)
1153 		{
1154 			switch (s.compare_func)
1155 			{
1156 			case MSL_SAMPLER_COMPARE_FUNC_ALWAYS:
1157 				args.push_back("compare_func::always");
1158 				break;
1159 			case MSL_SAMPLER_COMPARE_FUNC_NEVER:
1160 				args.push_back("compare_func::never");
1161 				break;
1162 			case MSL_SAMPLER_COMPARE_FUNC_EQUAL:
1163 				args.push_back("compare_func::equal");
1164 				break;
1165 			case MSL_SAMPLER_COMPARE_FUNC_NOT_EQUAL:
1166 				args.push_back("compare_func::not_equal");
1167 				break;
1168 			case MSL_SAMPLER_COMPARE_FUNC_LESS:
1169 				args.push_back("compare_func::less");
1170 				break;
1171 			case MSL_SAMPLER_COMPARE_FUNC_LESS_EQUAL:
1172 				args.push_back("compare_func::less_equal");
1173 				break;
1174 			case MSL_SAMPLER_COMPARE_FUNC_GREATER:
1175 				args.push_back("compare_func::greater");
1176 				break;
1177 			case MSL_SAMPLER_COMPARE_FUNC_GREATER_EQUAL:
1178 				args.push_back("compare_func::greater_equal");
1179 				break;
1180 			default:
1181 				SPIRV_CROSS_THROW("Invalid sampler compare function.");
1182 			}
1183 		}
1184 
1185 		if (s.s_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER || s.t_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER ||
1186 		    s.r_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER)
1187 		{
1188 			switch (s.border_color)
1189 			{
1190 			case MSL_SAMPLER_BORDER_COLOR_OPAQUE_BLACK:
1191 				args.push_back("border_color::opaque_black");
1192 				break;
1193 			case MSL_SAMPLER_BORDER_COLOR_OPAQUE_WHITE:
1194 				args.push_back("border_color::opaque_white");
1195 				break;
1196 			case MSL_SAMPLER_BORDER_COLOR_TRANSPARENT_BLACK:
1197 				args.push_back("border_color::transparent_black");
1198 				break;
1199 			default:
1200 				SPIRV_CROSS_THROW("Invalid sampler border color.");
1201 			}
1202 		}
1203 
1204 		if (s.anisotropy_enable)
1205 			args.push_back(join("max_anisotropy(", s.max_anisotropy, ")"));
1206 		if (s.lod_clamp_enable)
1207 		{
1208 			args.push_back(join("lod_clamp(", convert_to_string(s.lod_clamp_min, current_locale_radix_character), ", ",
1209 			                    convert_to_string(s.lod_clamp_max, current_locale_radix_character), ")"));
1210 		}
1211 
1212 		// If we would emit no arguments, then omit the parentheses entirely. Otherwise,
1213 		// we'll wind up with a "most vexing parse" situation.
1214 		if (args.empty())
1215 			statement("constexpr sampler ",
1216 			          type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first),
1217 			          ";");
1218 		else
1219 			statement("constexpr sampler ",
1220 			          type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first),
1221 			          "(", merge(args), ");");
1222 	}
1223 
1224 	// Emit dynamic buffers here.
1225 	for (auto &dynamic_buffer : buffers_requiring_dynamic_offset)
1226 	{
1227 		if (!dynamic_buffer.second.second)
1228 		{
1229 			// Could happen if no buffer was used at requested binding point.
1230 			continue;
1231 		}
1232 
1233 		const auto &var = get<SPIRVariable>(dynamic_buffer.second.second);
1234 		uint32_t var_id = var.self;
1235 		const auto &type = get_variable_data_type(var);
1236 		string name = to_name(var.self);
1237 		uint32_t desc_set = get_decoration(var.self, DecorationDescriptorSet);
1238 		uint32_t arg_id = argument_buffer_ids[desc_set];
1239 		uint32_t base_index = dynamic_buffer.second.first;
1240 
1241 		if (!type.array.empty())
1242 		{
1243 			// This is complicated, because we need to support arrays of arrays.
1244 			// And it's even worse if the outermost dimension is a runtime array, because now
1245 			// all this complicated goop has to go into the shader itself. (FIXME)
1246 			if (!type.array[type.array.size() - 1])
1247 				SPIRV_CROSS_THROW("Runtime arrays with dynamic offsets are not supported yet.");
1248 			else
1249 			{
1250 				is_using_builtin_array = true;
1251 				statement(get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id), name,
1252 				          type_to_array_glsl(type), " =");
1253 
1254 				uint32_t dim = uint32_t(type.array.size());
1255 				uint32_t j = 0;
1256 				for (SmallVector<uint32_t> indices(type.array.size());
1257 				     indices[type.array.size() - 1] < to_array_size_literal(type); j++)
1258 				{
1259 					while (dim > 0)
1260 					{
1261 						begin_scope();
1262 						--dim;
1263 					}
1264 
1265 					string arrays;
1266 					for (uint32_t i = uint32_t(type.array.size()); i; --i)
1267 						arrays += join("[", indices[i - 1], "]");
1268 					statement("(", get_argument_address_space(var), " ", type_to_glsl(type), "* ",
1269 					          to_restrict(var_id, false), ")((", get_argument_address_space(var), " char* ",
1270 					          to_restrict(var_id, false), ")", to_name(arg_id), ".", ensure_valid_name(name, "m"),
1271 					          arrays, " + ", to_name(dynamic_offsets_buffer_id), "[", base_index + j, "]),");
1272 
1273 					while (++indices[dim] >= to_array_size_literal(type, dim) && dim < type.array.size() - 1)
1274 					{
1275 						end_scope(",");
1276 						indices[dim++] = 0;
1277 					}
1278 				}
1279 				end_scope_decl();
1280 				statement_no_indent("");
1281 				is_using_builtin_array = false;
1282 			}
1283 		}
1284 		else
1285 		{
1286 			statement(get_argument_address_space(var), " auto& ", to_restrict(var_id), name, " = *(",
1287 			          get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id, false), ")((",
1288 			          get_argument_address_space(var), " char* ", to_restrict(var_id, false), ")", to_name(arg_id), ".",
1289 			          ensure_valid_name(name, "m"), " + ", to_name(dynamic_offsets_buffer_id), "[", base_index, "]);");
1290 		}
1291 	}
1292 
1293 	// Emit buffer arrays here.
1294 	for (uint32_t array_id : buffer_arrays)
1295 	{
1296 		const auto &var = get<SPIRVariable>(array_id);
1297 		const auto &type = get_variable_data_type(var);
1298 		const auto &buffer_type = get_variable_element_type(var);
1299 		string name = to_name(array_id);
1300 		statement(get_argument_address_space(var), " ", type_to_glsl(buffer_type), "* ", to_restrict(array_id), name,
1301 		          "[] =");
1302 		begin_scope();
1303 		for (uint32_t i = 0; i < to_array_size_literal(type); ++i)
1304 			statement(name, "_", i, ",");
1305 		end_scope_decl();
1306 		statement_no_indent("");
1307 	}
1308 	// For some reason, without this, we end up emitting the arrays twice.
1309 	buffer_arrays.clear();
1310 
1311 	// Emit disabled fragment outputs.
1312 	std::sort(disabled_frag_outputs.begin(), disabled_frag_outputs.end());
1313 	for (uint32_t var_id : disabled_frag_outputs)
1314 	{
1315 		auto &var = get<SPIRVariable>(var_id);
1316 		add_local_variable_name(var_id);
1317 		statement(variable_decl(var), ";");
1318 		var.deferred_declaration = false;
1319 	}
1320 }
1321 
compile()1322 string CompilerMSL::compile()
1323 {
1324 	replace_illegal_entry_point_names();
1325 	ir.fixup_reserved_names();
1326 
1327 	// Do not deal with GLES-isms like precision, older extensions and such.
1328 	options.vulkan_semantics = true;
1329 	options.es = false;
1330 	options.version = 450;
1331 	backend.null_pointer_literal = "nullptr";
1332 	backend.float_literal_suffix = false;
1333 	backend.uint32_t_literal_suffix = true;
1334 	backend.int16_t_literal_suffix = "";
1335 	backend.uint16_t_literal_suffix = "";
1336 	backend.basic_int_type = "int";
1337 	backend.basic_uint_type = "uint";
1338 	backend.basic_int8_type = "char";
1339 	backend.basic_uint8_type = "uchar";
1340 	backend.basic_int16_type = "short";
1341 	backend.basic_uint16_type = "ushort";
1342 	backend.discard_literal = "discard_fragment()";
1343 	backend.demote_literal = "discard_fragment()";
1344 	backend.boolean_mix_function = "select";
1345 	backend.swizzle_is_function = false;
1346 	backend.shared_is_implied = false;
1347 	backend.use_initializer_list = true;
1348 	backend.use_typed_initializer_list = true;
1349 	backend.native_row_major_matrix = false;
1350 	backend.unsized_array_supported = false;
1351 	backend.can_declare_arrays_inline = false;
1352 	backend.allow_truncated_access_chain = true;
1353 	backend.comparison_image_samples_scalar = true;
1354 	backend.native_pointers = true;
1355 	backend.nonuniform_qualifier = "";
1356 	backend.support_small_type_sampling_result = true;
1357 	backend.supports_empty_struct = true;
1358 
1359 	// Allow Metal to use the array<T> template unless we force it off.
1360 	backend.can_return_array = !msl_options.force_native_arrays;
1361 	backend.array_is_value_type = !msl_options.force_native_arrays;
1362 	// Arrays which are part of buffer objects are never considered to be native arrays.
1363 	backend.buffer_offset_array_is_value_type = false;
1364 	backend.support_pointer_to_pointer = true;
1365 
1366 	capture_output_to_buffer = msl_options.capture_output_to_buffer;
1367 	is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer;
1368 
1369 	// Initialize array here rather than constructor, MSVC 2013 workaround.
1370 	for (auto &id : next_metal_resource_ids)
1371 		id = 0;
1372 
1373 	fixup_type_alias();
1374 	replace_illegal_names();
1375 	sync_entry_point_aliases_and_names();
1376 
1377 	build_function_control_flow_graphs_and_analyze();
1378 	update_active_builtins();
1379 	analyze_image_and_sampler_usage();
1380 	analyze_sampled_image_usage();
1381 	analyze_interlocked_resource_usage();
1382 	preprocess_op_codes();
1383 	build_implicit_builtins();
1384 
1385 	fixup_image_load_store_access();
1386 
1387 	set_enabled_interface_variables(get_active_interface_variables());
1388 	if (msl_options.force_active_argument_buffer_resources)
1389 		activate_argument_buffer_resources();
1390 
1391 	if (swizzle_buffer_id)
1392 		active_interface_variables.insert(swizzle_buffer_id);
1393 	if (buffer_size_buffer_id)
1394 		active_interface_variables.insert(buffer_size_buffer_id);
1395 	if (view_mask_buffer_id)
1396 		active_interface_variables.insert(view_mask_buffer_id);
1397 	if (dynamic_offsets_buffer_id)
1398 		active_interface_variables.insert(dynamic_offsets_buffer_id);
1399 	if (builtin_layer_id)
1400 		active_interface_variables.insert(builtin_layer_id);
1401 	if (builtin_dispatch_base_id && !msl_options.supports_msl_version(1, 2))
1402 		active_interface_variables.insert(builtin_dispatch_base_id);
1403 	if (builtin_sample_mask_id)
1404 		active_interface_variables.insert(builtin_sample_mask_id);
1405 
1406 	// Create structs to hold input, output and uniform variables.
1407 	// Do output first to ensure out. is declared at top of entry function.
1408 	qual_pos_var_name = "";
1409 	stage_out_var_id = add_interface_block(StorageClassOutput);
1410 	patch_stage_out_var_id = add_interface_block(StorageClassOutput, true);
1411 	stage_in_var_id = add_interface_block(StorageClassInput);
1412 	if (get_execution_model() == ExecutionModelTessellationEvaluation)
1413 		patch_stage_in_var_id = add_interface_block(StorageClassInput, true);
1414 
1415 	if (get_execution_model() == ExecutionModelTessellationControl)
1416 		stage_out_ptr_var_id = add_interface_block_pointer(stage_out_var_id, StorageClassOutput);
1417 	if (is_tessellation_shader())
1418 		stage_in_ptr_var_id = add_interface_block_pointer(stage_in_var_id, StorageClassInput);
1419 
1420 	// Metal vertex functions that define no output must disable rasterization and return void.
1421 	if (!stage_out_var_id)
1422 		is_rasterization_disabled = true;
1423 
1424 	// Convert the use of global variables to recursively-passed function parameters
1425 	localize_global_variables();
1426 	extract_global_variables_from_functions();
1427 
1428 	// Mark any non-stage-in structs to be tightly packed.
1429 	mark_packable_structs();
1430 	reorder_type_alias();
1431 
1432 	// Add fixup hooks required by shader inputs and outputs. This needs to happen before
1433 	// the loop, so the hooks aren't added multiple times.
1434 	fix_up_shader_inputs_outputs();
1435 
1436 	// If we are using argument buffers, we create argument buffer structures for them here.
1437 	// These buffers will be used in the entry point, not the individual resources.
1438 	if (msl_options.argument_buffers)
1439 	{
1440 		if (!msl_options.supports_msl_version(2, 0))
1441 			SPIRV_CROSS_THROW("Argument buffers can only be used with MSL 2.0 and up.");
1442 		analyze_argument_buffers();
1443 	}
1444 
1445 	uint32_t pass_count = 0;
1446 	do
1447 	{
1448 		if (pass_count >= 3)
1449 			SPIRV_CROSS_THROW("Over 3 compilation loops detected. Must be a bug!");
1450 
1451 		reset();
1452 
1453 		// Start bindings at zero.
1454 		next_metal_resource_index_buffer = 0;
1455 		next_metal_resource_index_texture = 0;
1456 		next_metal_resource_index_sampler = 0;
1457 		for (auto &id : next_metal_resource_ids)
1458 			id = 0;
1459 
1460 		// Move constructor for this type is broken on GCC 4.9 ...
1461 		buffer.reset();
1462 
1463 		emit_header();
1464 		emit_custom_templates();
1465 		emit_specialization_constants_and_structs();
1466 		emit_resources();
1467 		emit_custom_functions();
1468 		emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());
1469 
1470 		pass_count++;
1471 	} while (is_forcing_recompilation());
1472 
1473 	return buffer.str();
1474 }
1475 
1476 // Register the need to output any custom functions.
preprocess_op_codes()1477 void CompilerMSL::preprocess_op_codes()
1478 {
1479 	OpCodePreprocessor preproc(*this);
1480 	traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), preproc);
1481 
1482 	suppress_missing_prototypes = preproc.suppress_missing_prototypes;
1483 
1484 	if (preproc.uses_atomics)
1485 	{
1486 		add_header_line("#include <metal_atomic>");
1487 		add_pragma_line("#pragma clang diagnostic ignored \"-Wunused-variable\"");
1488 	}
1489 
1490 	// Before MSL 2.1 (2.2 for textures), Metal vertex functions that write to
1491 	// resources must disable rasterization and return void.
1492 	if (preproc.uses_resource_write)
1493 		is_rasterization_disabled = true;
1494 
1495 	// Tessellation control shaders are run as compute functions in Metal, and so
1496 	// must capture their output to a buffer.
1497 	if (get_execution_model() == ExecutionModelTessellationControl ||
1498 	    (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
1499 	{
1500 		is_rasterization_disabled = true;
1501 		capture_output_to_buffer = true;
1502 	}
1503 
1504 	if (preproc.needs_subgroup_invocation_id)
1505 		needs_subgroup_invocation_id = true;
1506 	if (preproc.needs_subgroup_size)
1507 		needs_subgroup_size = true;
1508 	// build_implicit_builtins() hasn't run yet, and in fact, this needs to execute
1509 	// before then so that gl_SampleID will get added; so we also need to check if
1510 	// that function would add gl_FragCoord.
1511 	if (preproc.needs_sample_id || msl_options.force_sample_rate_shading ||
1512 	    (is_sample_rate() && (active_input_builtins.get(BuiltInFragCoord) ||
1513 	                          (need_subpass_input && !msl_options.use_framebuffer_fetch_subpasses))))
1514 		needs_sample_id = true;
1515 }
1516 
1517 // Move the Private and Workgroup global variables to the entry function.
1518 // Non-constant variables cannot have global scope in Metal.
localize_global_variables()1519 void CompilerMSL::localize_global_variables()
1520 {
1521 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1522 	auto iter = global_variables.begin();
1523 	while (iter != global_variables.end())
1524 	{
1525 		uint32_t v_id = *iter;
1526 		auto &var = get<SPIRVariable>(v_id);
1527 		if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup)
1528 		{
1529 			if (!variable_is_lut(var))
1530 				entry_func.add_local_variable(v_id);
1531 			iter = global_variables.erase(iter);
1532 		}
1533 		else
1534 			iter++;
1535 	}
1536 }
1537 
1538 // For any global variable accessed directly by a function,
1539 // extract that variable and add it as an argument to that function.
extract_global_variables_from_functions()1540 void CompilerMSL::extract_global_variables_from_functions()
1541 {
1542 	// Uniforms
1543 	unordered_set<uint32_t> global_var_ids;
1544 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1545 		if (var.storage == StorageClassInput || var.storage == StorageClassOutput ||
1546 		    var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
1547 		    var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer)
1548 		{
1549 			global_var_ids.insert(var.self);
1550 		}
1551 	});
1552 
1553 	// Local vars that are declared in the main function and accessed directly by a function
1554 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1555 	for (auto &var : entry_func.local_variables)
1556 		if (get<SPIRVariable>(var).storage != StorageClassFunction)
1557 			global_var_ids.insert(var);
1558 
1559 	std::set<uint32_t> added_arg_ids;
1560 	unordered_set<uint32_t> processed_func_ids;
1561 	extract_global_variables_from_function(ir.default_entry_point, added_arg_ids, global_var_ids, processed_func_ids);
1562 }
1563 
// MSL does not support the use of global variables for shader input content.
// For any global variable accessed directly by the specified function, extract that variable,
// add it as an argument to that function, and the arg to the added_arg_ids collection.
//
// func_id:            ID of the SPIRFunction to analyze.
// added_arg_ids:      Receives the IDs of the variables the function needs, whether it uses them
//                     directly or through a function it calls.
// global_var_ids:     IDs of the variables which this analysis treats as globals.
// processed_func_ids: IDs of the functions that have already been visited. For such a function the
//                     result stored in function_global_vars is returned instead of being recomputed.
void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::set<uint32_t> &added_arg_ids,
                                                         unordered_set<uint32_t> &global_var_ids,
                                                         unordered_set<uint32_t> &processed_func_ids)
{
	// Avoid processing a function more than once
	if (processed_func_ids.find(func_id) != processed_func_ids.end())
	{
		// Return function global variables
		// Note that this assignment replaces whatever the caller had in added_arg_ids.
		added_arg_ids = function_global_vars[func_id];
		return;
	}

	processed_func_ids.insert(func_id);

	auto &func = get<SPIRFunction>(func_id);

	// Recursively establish global args added to functions on which we depend.
	for (auto block : func.blocks)
	{
		auto &b = get<SPIRBlock>(block);
		for (auto &i : b.ops)
		{
			auto ops = stream(i);
			auto op = static_cast<Op>(i.op);

			switch (op)
			{
			case OpLoad:
			case OpInBoundsAccessChain:
			case OpAccessChain:
			case OpPtrAccessChain:
			case OpArrayLength:
			{
				// For these opcodes operand 2 is the ID being read or indexed.
				uint32_t base_id = ops[2];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);

				// Operand 0 is looked up as the result type. If it is a subpass-input image and
				// framebuffer fetch is NOT used for subpass inputs, extra builtins are required.
				auto &type = get<SPIRType>(ops[0]);
				if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
				    (!msl_options.use_framebuffer_fetch_subpasses))
				{
					// Implicitly reads gl_FragCoord.
					assert(builtin_frag_coord_id != 0);
					added_arg_ids.insert(builtin_frag_coord_id);
					if (msl_options.multiview)
					{
						// Implicitly reads gl_ViewIndex.
						assert(builtin_view_idx_id != 0);
						added_arg_ids.insert(builtin_view_idx_id);
					}
					else if (msl_options.arrayed_subpass_input)
					{
						// Implicitly reads gl_Layer.
						assert(builtin_layer_id != 0);
						added_arg_ids.insert(builtin_layer_id);
					}
				}

				break;
			}

			case OpFunctionCall:
			{
				// First see if any of the function call args are globals
				// (the call arguments start at operand index 3).
				for (uint32_t arg_idx = 3; arg_idx < i.length; arg_idx++)
				{
					uint32_t arg_id = ops[arg_idx];
					if (global_var_ids.find(arg_id) != global_var_ids.end())
						added_arg_ids.insert(arg_id);
				}

				// Then recurse into the function itself to extract globals used internally in the function
				// Everything the callee needs must also be available in this function so it can be passed on.
				uint32_t inner_func_id = ops[2];
				std::set<uint32_t> inner_func_args;
				extract_global_variables_from_function(inner_func_id, inner_func_args, global_var_ids,
				                                       processed_func_ids);
				added_arg_ids.insert(inner_func_args.begin(), inner_func_args.end());
				break;
			}

			case OpStore:
			{
				// Both the destination (operand 0) and the stored value (operand 1) are checked.
				uint32_t base_id = ops[0];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);

				uint32_t rvalue_id = ops[1];
				if (global_var_ids.find(rvalue_id) != global_var_ids.end())
					added_arg_ids.insert(rvalue_id);

				break;
			}

			case OpSelect:
			{
				// Operands 3 and 4 are the two values being selected between; either may be a global.
				uint32_t base_id = ops[3];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);
				base_id = ops[4];
				if (global_var_ids.find(base_id) != global_var_ids.end())
					added_arg_ids.insert(base_id);
				break;
			}

			// Emulate texture2D atomic operations
			case OpImageTexelPointer:
			{
				// When using the pointer, we need to know which variable it is actually loaded from.
				// The ID is only recorded when its backing variable is listed in atomic_image_vars.
				uint32_t base_id = ops[2];
				auto *var = maybe_get_backing_variable(base_id);
				if (var && atomic_image_vars.count(var->self))
				{
					if (global_var_ids.find(base_id) != global_var_ids.end())
						added_arg_ids.insert(base_id);
				}
				break;
			}

			case OpExtInst:
			{
				// Operand 2 identifies the extended instruction set, operand 3 the instruction within it.
				uint32_t extension_set = ops[2];
				if (get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
				{
					auto op_450 = static_cast<GLSLstd450>(ops[3]);
					switch (op_450)
					{
					case GLSLstd450InterpolateAtCentroid:
					case GLSLstd450InterpolateAtSample:
					case GLSLstd450InterpolateAtOffset:
					{
						// For these, we really need the stage-in block. It is theoretically possible to pass the
						// interpolant object, but a) doing so would require us to create an entirely new variable
						// with Interpolant type, and b) if we have a struct or array, handling all the members and
						// elements could get unwieldy fast.
						added_arg_ids.insert(stage_in_var_id);
						break;
					}
					default:
						break;
					}
				}
				break;
			}

			// The subgroup ballot operations below add a subgroup builtin unconditionally,
			// without consulting global_var_ids.
			case OpGroupNonUniformInverseBallot:
			{
				added_arg_ids.insert(builtin_subgroup_invocation_id_id);
				break;
			}

			case OpGroupNonUniformBallotFindLSB:
			case OpGroupNonUniformBallotFindMSB:
			{
				added_arg_ids.insert(builtin_subgroup_size_id);
				break;
			}

			case OpGroupNonUniformBallotBitCount:
			{
				// The builtin that is needed depends on the group operation held in operand 3.
				auto operation = static_cast<GroupOperation>(ops[3]);
				switch (operation)
				{
				case GroupOperationReduce:
					added_arg_ids.insert(builtin_subgroup_size_id);
					break;
				case GroupOperationInclusiveScan:
				case GroupOperationExclusiveScan:
					added_arg_ids.insert(builtin_subgroup_invocation_id_id);
					break;
				default:
					break;
				}
				break;
			}

			default:
				break;
			}

			// TODO: Add all other operations which can affect memory.
			// We should consider a more unified system here to reduce boiler-plate.
			// This kind of analysis is done in several places ...
		}
	}

	// Remember the result so that later calls for the same function can return it directly.
	function_global_vars[func_id] = added_arg_ids;

	// Add the global variables as arguments to the function
	// The entry point itself is skipped; only the functions it calls get the extra parameters.
	if (func_id != ir.default_entry_point)
	{
		// Each of the four aggregate stage-IO parameters (control-point in/out, patch in/out)
		// is added at most once per function; these flags record which ones have been added.
		bool control_point_added_in = false;
		bool control_point_added_out = false;
		bool patch_added_in = false;
		bool patch_added_out = false;

		for (uint32_t arg_id : added_arg_ids)
		{
			auto &var = get<SPIRVariable>(arg_id);
			uint32_t type_id = var.basetype;
			auto *p_type = &get<SPIRType>(type_id);
			BuiltIn bi_type = BuiltIn(get_decoration(arg_id, DecorationBuiltIn));

			// Classify the variable to decide whether it is accessed through one of the
			// aggregate stage-IO arrays/structs instead of being passed on its own.
			bool is_patch = has_decoration(arg_id, DecorationPatch) || is_patch_block(*p_type);
			bool is_block = has_decoration(p_type->self, DecorationBlock);
			bool is_control_point_storage =
					!is_patch &&
					((is_tessellation_shader() && var.storage == StorageClassInput) ||
					 (get_execution_model() == ExecutionModelTessellationControl && var.storage == StorageClassOutput));
			bool is_patch_block_storage = is_patch && is_block && var.storage == StorageClassOutput;
			bool is_builtin = is_builtin_variable(var);
			bool variable_is_stage_io =
					!is_builtin || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
					bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
					p_type->basetype == SPIRType::Struct;
			bool is_redirected_to_global_stage_io = (is_control_point_storage || is_patch_block_storage) &&
			                                        variable_is_stage_io;

			// If output is masked it is not considered part of the global stage IO interface.
			if (is_redirected_to_global_stage_io && var.storage == StorageClassOutput)
				is_redirected_to_global_stage_io = !is_stage_output_variable_masked(var);

			if (is_redirected_to_global_stage_io)
			{
				// Tessellation control shaders see inputs and per-vertex outputs as arrays.
				// Similarly, tessellation evaluation shaders see per-vertex inputs as arrays.
				// We collected them into a structure; we must pass the array of this
				// structure to the function.
				std::string name;
				if (is_patch)
					name = var.storage == StorageClassInput ? patch_stage_in_var_name : patch_stage_out_var_name;
				else
					name = var.storage == StorageClassInput ? "gl_in" : "gl_out";

				if (var.storage == StorageClassOutput && has_decoration(p_type->self, DecorationBlock))
				{
					// If we're redirecting a block, we might still need to access the original block
					// variable if we're masking some members.
					// One masked member is enough: the original variable is added once, then we stop looking.
					for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(p_type->member_types.size()); mbr_idx++)
					{
						if (is_stage_output_block_member_masked(var, mbr_idx, true))
						{
							func.add_parameter(var.basetype, var.self, true);
							break;
						}
					}
				}

				// Replace arg_id with the ID of the aggregate stage-IO variable for this direction.
				// If that aggregate was already added as a parameter of this function, skip the variable.
				if (var.storage == StorageClassInput)
				{
					auto &added_in = is_patch ? patch_added_in : control_point_added_in;
					if (added_in)
						continue;
					arg_id = is_patch ? patch_stage_in_var_id : stage_in_ptr_var_id;
					added_in = true;
				}
				else if (var.storage == StorageClassOutput)
				{
					auto &added_out = is_patch ? patch_added_out : control_point_added_out;
					if (added_out)
						continue;
					arg_id = is_patch ? patch_stage_out_var_id : stage_out_ptr_var_id;
					added_out = true;
				}

				// Create a new Function-storage variable with the aggregate's type and add it as a parameter.
				// NOTE(review): the last SPIRVariable argument is presumably the base variable - confirm.
				type_id = get<SPIRVariable>(arg_id).basetype;
				uint32_t next_id = ir.increase_bound_by(1);
				func.add_parameter(type_id, next_id, true);
				set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);

				set_name(next_id, name);
			}
			else if (is_builtin && has_decoration(p_type->self, DecorationBlock))
			{
				// A builtin block is not passed as a whole: every active builtin member
				// becomes a separate parameter.

				// Get the pointee type
				type_id = get_pointee_type_id(type_id);
				p_type = &get<SPIRType>(type_id);

				uint32_t mbr_idx = 0;
				for (auto &mbr_type_id : p_type->member_types)
				{
					BuiltIn builtin = BuiltInMax;
					is_builtin = is_member_builtin(*p_type, mbr_idx, &builtin);
					if (is_builtin && has_active_builtin(builtin, var.storage))
					{
						// Add a arg variable with the same type and decorations as the member
						// Two new IDs are reserved: one for the pointer type, one for the variable.
						uint32_t next_ids = ir.increase_bound_by(2);
						uint32_t ptr_type_id = next_ids + 0;
						uint32_t var_id = next_ids + 1;

						// Make sure we have an actual pointer type,
						// so that we will get the appropriate address space when declaring these builtins.
						auto &ptr = set<SPIRType>(ptr_type_id, get<SPIRType>(mbr_type_id));
						ptr.self = mbr_type_id;
						ptr.storage = var.storage;
						ptr.pointer = true;
						ptr.pointer_depth++;
						ptr.parent_type = mbr_type_id;

						func.add_parameter(mbr_type_id, var_id, true);
						set<SPIRVariable>(var_id, ptr_type_id, StorageClassFunction);
						// The new variable takes over the decorations recorded for the block member.
						ir.meta[var_id].decoration = ir.meta[type_id].members[mbr_idx];
					}
					mbr_idx++;
				}
			}
			else
			{
				// Default case: add a parameter of the variable's own type, backed by a new
				// Function-storage variable which is created with a reference to the original ID.
				uint32_t next_id = ir.increase_bound_by(1);
				func.add_parameter(type_id, next_id, true);
				set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);

				// Ensure the existing variable has a valid name and the new variable has all the same meta info
				set_name(arg_id, ensure_valid_name(to_name(arg_id), "v"));
				ir.meta[next_id] = ir.meta[arg_id];
			}
		}
	}
}
1890 
1891 // For all variables that are some form of non-input-output interface block, mark that all the structs
1892 // that are recursively contained within the type referenced by that variable should be packed tightly.
mark_packable_structs()1893 void CompilerMSL::mark_packable_structs()
1894 {
1895 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1896 		if (var.storage != StorageClassFunction && !is_hidden_variable(var))
1897 		{
1898 			auto &type = this->get<SPIRType>(var.basetype);
1899 			if (type.pointer &&
1900 			    (type.storage == StorageClassUniform || type.storage == StorageClassUniformConstant ||
1901 			     type.storage == StorageClassPushConstant || type.storage == StorageClassStorageBuffer) &&
1902 			    (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
1903 				mark_as_packable(type);
1904 		}
1905 	});
1906 }
1907 
1908 // If the specified type is a struct, it and any nested structs
1909 // are marked as packable with the SPIRVCrossDecorationBufferBlockRepacked decoration,
mark_as_packable(SPIRType & type)1910 void CompilerMSL::mark_as_packable(SPIRType &type)
1911 {
1912 	// If this is not the base type (eg. it's a pointer or array), tunnel down
1913 	if (type.parent_type)
1914 	{
1915 		mark_as_packable(get<SPIRType>(type.parent_type));
1916 		return;
1917 	}
1918 
1919 	if (type.basetype == SPIRType::Struct)
1920 	{
1921 		set_extended_decoration(type.self, SPIRVCrossDecorationBufferBlockRepacked);
1922 
1923 		// Recurse
1924 		uint32_t mbr_cnt = uint32_t(type.member_types.size());
1925 		for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
1926 		{
1927 			uint32_t mbr_type_id = type.member_types[mbr_idx];
1928 			auto &mbr_type = get<SPIRType>(mbr_type_id);
1929 			mark_as_packable(mbr_type);
1930 			if (mbr_type.type_alias)
1931 			{
1932 				auto &mbr_type_alias = get<SPIRType>(mbr_type.type_alias);
1933 				mark_as_packable(mbr_type_alias);
1934 			}
1935 		}
1936 	}
1937 }
1938 
1939 // If a shader input exists at the location, it is marked as being used by this shader
mark_location_as_used_by_shader(uint32_t location,const SPIRType & type,StorageClass storage,bool fallback)1940 void CompilerMSL::mark_location_as_used_by_shader(uint32_t location, const SPIRType &type,
1941                                                   StorageClass storage, bool fallback)
1942 {
1943 	if (storage != StorageClassInput)
1944 		return;
1945 
1946 	uint32_t count = type_to_location_count(type);
1947 	for (uint32_t i = 0; i < count; i++)
1948 	{
1949 		location_inputs_in_use.insert(location + i);
1950 		if (fallback)
1951 			location_inputs_in_use_fallback.insert(location + i);
1952 	}
1953 }
1954 
get_target_components_for_fragment_location(uint32_t location) const1955 uint32_t CompilerMSL::get_target_components_for_fragment_location(uint32_t location) const
1956 {
1957 	auto itr = fragment_output_components.find(location);
1958 	if (itr == end(fragment_output_components))
1959 		return 4;
1960 	else
1961 		return itr->second;
1962 }
1963 
build_extended_vector_type(uint32_t type_id,uint32_t components,SPIRType::BaseType basetype)1964 uint32_t CompilerMSL::build_extended_vector_type(uint32_t type_id, uint32_t components, SPIRType::BaseType basetype)
1965 {
1966 	uint32_t new_type_id = ir.increase_bound_by(1);
1967 	auto &old_type = get<SPIRType>(type_id);
1968 	auto *type = &set<SPIRType>(new_type_id, old_type);
1969 	type->vecsize = components;
1970 	if (basetype != SPIRType::Unknown)
1971 		type->basetype = basetype;
1972 	type->self = new_type_id;
1973 	type->parent_type = type_id;
1974 	type->array.clear();
1975 	type->array_size_literal.clear();
1976 	type->pointer = false;
1977 
1978 	if (is_array(old_type))
1979 	{
1980 		uint32_t array_type_id = ir.increase_bound_by(1);
1981 		type = &set<SPIRType>(array_type_id, *type);
1982 		type->parent_type = new_type_id;
1983 		type->array = old_type.array;
1984 		type->array_size_literal = old_type.array_size_literal;
1985 		new_type_id = array_type_id;
1986 	}
1987 
1988 	if (old_type.pointer)
1989 	{
1990 		uint32_t ptr_type_id = ir.increase_bound_by(1);
1991 		type = &set<SPIRType>(ptr_type_id, *type);
1992 		type->self = new_type_id;
1993 		type->parent_type = new_type_id;
1994 		type->storage = old_type.storage;
1995 		type->pointer = true;
1996 		type->pointer_depth++;
1997 		new_type_id = ptr_type_id;
1998 	}
1999 
2000 	return new_type_id;
2001 }
2002 
build_msl_interpolant_type(uint32_t type_id,bool is_noperspective)2003 uint32_t CompilerMSL::build_msl_interpolant_type(uint32_t type_id, bool is_noperspective)
2004 {
2005 	uint32_t new_type_id = ir.increase_bound_by(1);
2006 	SPIRType &type = set<SPIRType>(new_type_id, get<SPIRType>(type_id));
2007 	type.basetype = SPIRType::Interpolant;
2008 	type.parent_type = type_id;
2009 	// In Metal, the pull-model interpolant type encodes perspective-vs-no-perspective in the type itself.
2010 	// Add this decoration so we know which argument to pass to the template.
2011 	if (is_noperspective)
2012 		set_decoration(new_type_id, DecorationNoPerspective);
2013 	return new_type_id;
2014 }
2015 
add_component_variable_to_interface_block(spv::StorageClass storage,const std::string & ib_var_ref,SPIRVariable & var,const SPIRType & type,InterfaceBlockMeta & meta)2016 bool CompilerMSL::add_component_variable_to_interface_block(spv::StorageClass storage, const std::string &ib_var_ref,
2017                                                             SPIRVariable &var,
2018                                                             const SPIRType &type,
2019                                                             InterfaceBlockMeta &meta)
2020 {
2021 	// Deal with Component decorations.
2022 	const InterfaceBlockMeta::LocationMeta *location_meta = nullptr;
2023 	uint32_t location = ~0u;
2024 	if (has_decoration(var.self, DecorationLocation))
2025 	{
2026 		location = get_decoration(var.self, DecorationLocation);
2027 		auto location_meta_itr = meta.location_meta.find(location);
2028 		if (location_meta_itr != end(meta.location_meta))
2029 			location_meta = &location_meta_itr->second;
2030 	}
2031 
2032 	// Check if we need to pad fragment output to match a certain number of components.
2033 	if (location_meta)
2034 	{
2035 		bool pad_fragment_output = has_decoration(var.self, DecorationLocation) &&
2036 		                           msl_options.pad_fragment_output_components &&
2037 		                           get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput;
2038 
2039 		auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2040 		uint32_t start_component = get_decoration(var.self, DecorationComponent);
2041 		uint32_t type_components = type.vecsize;
2042 		uint32_t num_components = location_meta->num_components;
2043 
2044 		if (pad_fragment_output)
2045 		{
2046 			uint32_t locn = get_decoration(var.self, DecorationLocation);
2047 			num_components = std::max(num_components, get_target_components_for_fragment_location(locn));
2048 		}
2049 
2050 		// We have already declared an IO block member as m_location_N.
2051 		// Just emit an early-declared variable and fixup as needed.
2052 		// Arrays need to be unrolled here since each location might need a different number of components.
2053 		entry_func.add_local_variable(var.self);
2054 		vars_needing_early_declaration.push_back(var.self);
2055 
2056 		if (var.storage == StorageClassInput)
2057 		{
2058 			entry_func.fixup_hooks_in.push_back([=, &type, &var]() {
2059 				if (!type.array.empty())
2060 				{
2061 					uint32_t array_size = to_array_size_literal(type);
2062 					for (uint32_t loc_off = 0; loc_off < array_size; loc_off++)
2063 					{
2064 						statement(to_name(var.self), "[", loc_off, "]", " = ", ib_var_ref,
2065 						          ".m_location_", location + loc_off,
2066 						          vector_swizzle(type_components, start_component), ";");
2067 					}
2068 				}
2069 				else
2070 				{
2071 					statement(to_name(var.self), " = ", ib_var_ref, ".m_location_", location,
2072 					          vector_swizzle(type_components, start_component), ";");
2073 				}
2074 			});
2075 		}
2076 		else
2077 		{
2078 			entry_func.fixup_hooks_out.push_back([=, &type, &var]() {
2079 				if (!type.array.empty())
2080 				{
2081 					uint32_t array_size = to_array_size_literal(type);
2082 					for (uint32_t loc_off = 0; loc_off < array_size; loc_off++)
2083 					{
2084 						statement(ib_var_ref, ".m_location_", location + loc_off,
2085 						          vector_swizzle(type_components, start_component), " = ",
2086 						          to_name(var.self), "[", loc_off, "];");
2087 					}
2088 				}
2089 				else
2090 				{
2091 					statement(ib_var_ref, ".m_location_", location,
2092 					          vector_swizzle(type_components, start_component), " = ", to_name(var.self), ";");
2093 				}
2094 			});
2095 		}
2096 		return true;
2097 	}
2098 	else
2099 		return false;
2100 }
2101 
// Adds a single variable as one member of the interface block struct 'ib_type'.
//
// storage:    Storage class of the interface block being built.
// ib_var_ref: Expression naming the interface block variable; used to build the qualified
//             member reference (ib_var_ref + "." + member name).
// ib_type:    The interface block struct type; a member is appended to it.
// var:        The variable to add. Its basetype may be replaced along the way.
// meta:       Interface block metadata (strip_array and location_meta are consulted).
//
// Besides appending the member, this registers entry-point fixup hooks where the variable has to be
// copied to/from the block member or initialized, and copies the relevant decorations to the member.
void CompilerMSL::add_plain_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
                                                        SPIRType &ib_type, SPIRVariable &var, InterfaceBlockMeta &meta)
{
	// Snapshot the decorations of the variable that are needed further down.
	bool is_builtin = is_builtin_variable(var);
	BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
	bool is_flat = has_decoration(var.self, DecorationFlat);
	bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
	bool is_centroid = has_decoration(var.self, DecorationCentroid);
	bool is_sample = has_decoration(var.self, DecorationSample);

	// Add a reference to the variable type to the interface struct.
	// The new member will be placed at the current end of the member list.
	uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
	uint32_t type_id = ensure_correct_builtin_type(var.basetype, builtin);
	var.basetype = type_id;

	// Work with the pointee type and, when arrays are being stripped, with the array's parent type.
	type_id = get_pointee_type_id(var.basetype);
	if (meta.strip_array && is_array(get<SPIRType>(type_id)))
		type_id = get<SPIRType>(type_id).parent_type;
	auto &type = get<SPIRType>(type_id);
	uint32_t target_components = 0;
	uint32_t type_components = type.vecsize;

	// Note: within this function padded_input is never set to true and start_component is never
	// changed from 0, so only the padded-output path below can actually be taken.
	bool padded_output = false;
	bool padded_input = false;
	uint32_t start_component = 0;

	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);

	// If the variable was handled by the component path, there is nothing more to do here.
	if (add_component_variable_to_interface_block(storage, ib_var_ref, var, type, meta))
		return;

	// Padding applies only to fragment shader outputs with a Location, and only if the option is enabled.
	bool pad_fragment_output = has_decoration(var.self, DecorationLocation) &&
	                           msl_options.pad_fragment_output_components &&
	                           get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput;

	if (pad_fragment_output)
	{
		uint32_t locn = get_decoration(var.self, DecorationLocation);
		target_components = get_target_components_for_fragment_location(locn);
		if (type_components < target_components)
		{
			// Make a new type here.
			type_id = build_extended_vector_type(type_id, target_components);
			padded_output = true;
		}
	}

	// Pull-model inputs are declared with an interpolant type instead of the plain type.
	if (storage == StorageClassInput && pull_model_inputs.count(var.self))
		ib_type.member_types.push_back(build_msl_interpolant_type(type_id, is_noperspective));
	else
		ib_type.member_types.push_back(type_id);

	// Give the member a name
	string mbr_name = ensure_valid_name(to_expression(var.self), "m");
	set_member_name(ib_type.self, ib_mbr_idx, mbr_name);

	// Update the original variable reference to include the structure reference
	string qual_var_name = ib_var_ref + "." + mbr_name;
	// If using pull-model interpolation, need to add a call to the correct interpolation method.
	if (storage == StorageClassInput && pull_model_inputs.count(var.self))
	{
		if (is_centroid)
			qual_var_name += ".interpolate_at_centroid()";
		else if (is_sample)
			qual_var_name += join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
		else
			qual_var_name += ".interpolate_at_center()";
	}

	if (padded_output || padded_input)
	{
		// The member type differs from the variable type, so the variable is kept as an early-declared
		// local of the entry function and copied to/from the swizzled member by a fixup hook.
		entry_func.add_local_variable(var.self);
		vars_needing_early_declaration.push_back(var.self);

		if (padded_output)
		{
			entry_func.fixup_hooks_out.push_back([=, &var]() {
				statement(qual_var_name, vector_swizzle(type_components, start_component), " = ", to_name(var.self),
				          ";");
			});
		}
		else
		{
			entry_func.fixup_hooks_in.push_back([=, &var]() {
				statement(to_name(var.self), " = ", qual_var_name, vector_swizzle(type_components, start_component),
				          ";");
			});
		}
	}
	else if (!meta.strip_array)
		// No copy needed: record the qualified member reference as the variable's alias.
		ir.meta[var.self].decoration.qualified_alias = qual_var_name;

	// Output variables with an initializer get an assignment of that initializer emitted on entry.
	if (var.storage == StorageClassOutput && var.initializer != ID(0))
	{
		if (padded_output || padded_input)
		{
			// Padded case: initialize the local variable.
			entry_func.fixup_hooks_in.push_back(
			    [=, &var]() { statement(to_name(var.self), " = ", to_expression(var.initializer), ";"); });
		}
		else
		{
			if (meta.strip_array)
			{
				// Stripped-array case: write the element of the initializer selected by the invocation ID
				// into the same element of the stage-out array. ib_type is captured by copy here.
				entry_func.fixup_hooks_in.push_back([=, &var]() {
					uint32_t index = get_extended_decoration(var.self, SPIRVCrossDecorationInterfaceMemberIndex);
					auto invocation = to_tesc_invocation_id();
					statement(to_expression(stage_out_ptr_var_id), "[",
					          invocation, "].",
					          to_member_name(ib_type, index), " = ", to_expression(var.initializer), "[",
					          invocation, "];");
				});
			}
			else
			{
				// Plain case: initialize the block member directly.
				entry_func.fixup_hooks_in.push_back([=, &var]() {
					statement(qual_var_name, " = ", to_expression(var.initializer), ";");
				});
			}
		}
	}

	// Copy the variable location from the original variable to the member
	if (get_decoration_bitset(var.self).get(DecorationLocation))
	{
		uint32_t locn = get_decoration(var.self, DecorationLocation);
		if (storage == StorageClassInput)
		{
			// For inputs the type may be replaced by ensure_correct_input_type, so the member type
			// stored above is recomputed the same way (pointee, array stripping, interpolant wrapper).
			type_id = ensure_correct_input_type(var.basetype, locn, 0, meta.strip_array);
			var.basetype = type_id;

			type_id = get_pointee_type_id(type_id);
			if (meta.strip_array && is_array(get<SPIRType>(type_id)))
				type_id = get<SPIRType>(type_id).parent_type;
			if (pull_model_inputs.count(var.self))
				ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(type_id, is_noperspective);
			else
				ib_type.member_types[ib_mbr_idx] = type_id;
		}
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
		mark_location_as_used_by_shader(locn, get<SPIRType>(type_id), storage);
	}
	else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
	{
		// No explicit Location: in tessellation shaders, use the location registered for this builtin.
		uint32_t locn = inputs_by_builtin[builtin].location;
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
		mark_location_as_used_by_shader(locn, type, storage);
	}

	// Copy the Component decoration to the member, if present.
	if (get_decoration_bitset(var.self).get(DecorationComponent))
	{
		uint32_t component = get_decoration(var.self, DecorationComponent);
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, component);
	}

	// Copy the Index decoration to the member, if present.
	if (get_decoration_bitset(var.self).get(DecorationIndex))
	{
		uint32_t index = get_decoration(var.self, DecorationIndex);
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
	}

	// Mark the member as builtin if needed
	if (is_builtin)
	{
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
		// Remember the qualified name of the output position member.
		if (builtin == BuiltInPosition && storage == StorageClassOutput)
			qual_pos_var_name = qual_var_name;
	}

	// Copy interpolation decorations if needed
	// (skipped for pull-model inputs, which were given an interpolant type and method call above).
	if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
	{
		if (is_flat)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
		if (is_noperspective)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
		if (is_centroid)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
		if (is_sample)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
	}

	// Record which original variable this member was created from.
	set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
}
2285 
add_composite_variable_to_interface_block(StorageClass storage,const string & ib_var_ref,SPIRType & ib_type,SPIRVariable & var,InterfaceBlockMeta & meta)2286 void CompilerMSL::add_composite_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2287                                                             SPIRType &ib_type, SPIRVariable &var,
2288                                                             InterfaceBlockMeta &meta)
2289 {
2290 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2291 	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2292 	uint32_t elem_cnt = 0;
2293 
2294 	if (add_component_variable_to_interface_block(storage, ib_var_ref, var, var_type, meta))
2295 		return;
2296 
2297 	if (is_matrix(var_type))
2298 	{
2299 		if (is_array(var_type))
2300 			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
2301 
2302 		elem_cnt = var_type.columns;
2303 	}
2304 	else if (is_array(var_type))
2305 	{
2306 		if (var_type.array.size() != 1)
2307 			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
2308 
2309 		elem_cnt = to_array_size_literal(var_type);
2310 	}
2311 
2312 	bool is_builtin = is_builtin_variable(var);
2313 	BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
2314 	bool is_flat = has_decoration(var.self, DecorationFlat);
2315 	bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
2316 	bool is_centroid = has_decoration(var.self, DecorationCentroid);
2317 	bool is_sample = has_decoration(var.self, DecorationSample);
2318 
2319 	auto *usable_type = &var_type;
2320 	if (usable_type->pointer)
2321 		usable_type = &get<SPIRType>(usable_type->parent_type);
2322 	while (is_array(*usable_type) || is_matrix(*usable_type))
2323 		usable_type = &get<SPIRType>(usable_type->parent_type);
2324 
2325 	// If a builtin, force it to have the proper name.
2326 	if (is_builtin)
2327 		set_name(var.self, builtin_to_glsl(builtin, StorageClassFunction));
2328 
2329 	bool flatten_from_ib_var = false;
2330 	string flatten_from_ib_mbr_name;
2331 
2332 	if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
2333 	{
2334 		// Also declare [[clip_distance]] attribute here.
2335 		uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
2336 		ib_type.member_types.push_back(get_variable_data_type_id(var));
2337 		set_member_decoration(ib_type.self, clip_array_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
2338 
2339 		flatten_from_ib_mbr_name = builtin_to_glsl(BuiltInClipDistance, StorageClassOutput);
2340 		set_member_name(ib_type.self, clip_array_mbr_idx, flatten_from_ib_mbr_name);
2341 
2342 		// When we flatten, we flatten directly from the "out" struct,
2343 		// not from a function variable.
2344 		flatten_from_ib_var = true;
2345 
2346 		if (!msl_options.enable_clip_distance_user_varying)
2347 			return;
2348 	}
2349 	else if (!meta.strip_array)
2350 	{
2351 		// Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
2352 		entry_func.add_local_variable(var.self);
2353 		// We need to declare the variable early and at entry-point scope.
2354 		vars_needing_early_declaration.push_back(var.self);
2355 	}
2356 
2357 	for (uint32_t i = 0; i < elem_cnt; i++)
2358 	{
2359 		// Add a reference to the variable type to the interface struct.
2360 		uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2361 
2362 		uint32_t target_components = 0;
2363 		bool padded_output = false;
2364 		uint32_t type_id = usable_type->self;
2365 
2366 		// Check if we need to pad fragment output to match a certain number of components.
2367 		if (get_decoration_bitset(var.self).get(DecorationLocation) && msl_options.pad_fragment_output_components &&
2368 		    get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput)
2369 		{
2370 			uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
2371 			target_components = get_target_components_for_fragment_location(locn);
2372 			if (usable_type->vecsize < target_components)
2373 			{
2374 				// Make a new type here.
2375 				type_id = build_extended_vector_type(usable_type->self, target_components);
2376 				padded_output = true;
2377 			}
2378 		}
2379 
2380 		if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2381 			ib_type.member_types.push_back(build_msl_interpolant_type(get_pointee_type_id(type_id), is_noperspective));
2382 		else
2383 			ib_type.member_types.push_back(get_pointee_type_id(type_id));
2384 
2385 		// Give the member a name
2386 		string mbr_name = ensure_valid_name(join(to_expression(var.self), "_", i), "m");
2387 		set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2388 
2389 		// There is no qualified alias since we need to flatten the internal array on return.
2390 		if (get_decoration_bitset(var.self).get(DecorationLocation))
2391 		{
2392 			uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
2393 			uint32_t comp = get_decoration(var.self, DecorationComponent);
2394 			if (storage == StorageClassInput)
2395 			{
2396 				var.basetype = ensure_correct_input_type(var.basetype, locn, 0, meta.strip_array);
2397 				uint32_t mbr_type_id = ensure_correct_input_type(usable_type->self, locn, 0, meta.strip_array);
2398 				if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2399 					ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective);
2400 				else
2401 					ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2402 			}
2403 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2404 			if (comp)
2405 				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp);
2406 			mark_location_as_used_by_shader(locn, *usable_type, storage);
2407 		}
2408 		else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
2409 		{
2410 			uint32_t locn = inputs_by_builtin[builtin].location + i;
2411 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2412 			mark_location_as_used_by_shader(locn, *usable_type, storage);
2413 		}
2414 		else if (is_builtin && (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance))
2415 		{
2416 			// Declare the Clip/CullDistance as [[user(clip/cullN)]].
2417 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2418 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, i);
2419 		}
2420 
2421 		if (get_decoration_bitset(var.self).get(DecorationIndex))
2422 		{
2423 			uint32_t index = get_decoration(var.self, DecorationIndex);
2424 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
2425 		}
2426 
2427 		if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
2428 		{
2429 			// Copy interpolation decorations if needed
2430 			if (is_flat)
2431 				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2432 			if (is_noperspective)
2433 				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2434 			if (is_centroid)
2435 				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2436 			if (is_sample)
2437 				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2438 		}
2439 
2440 		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2441 
2442 		// Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
2443 		if (!meta.strip_array)
2444 		{
2445 			switch (storage)
2446 			{
2447 			case StorageClassInput:
2448 				entry_func.fixup_hooks_in.push_back([=, &var]() {
2449 					if (pull_model_inputs.count(var.self))
2450 					{
2451 						string lerp_call;
2452 						if (is_centroid)
2453 							lerp_call = ".interpolate_at_centroid()";
2454 						else if (is_sample)
2455 							lerp_call = join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
2456 						else
2457 							lerp_call = ".interpolate_at_center()";
2458 						statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, lerp_call, ";");
2459 					}
2460 					else
2461 					{
2462 						statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, ";");
2463 					}
2464 				});
2465 				break;
2466 
2467 			case StorageClassOutput:
2468 				entry_func.fixup_hooks_out.push_back([=, &var]() {
2469 					if (padded_output)
2470 					{
2471 						auto &padded_type = this->get<SPIRType>(type_id);
2472 						statement(
2473 						    ib_var_ref, ".", mbr_name, " = ",
2474 						    remap_swizzle(padded_type, usable_type->vecsize, join(to_name(var.self), "[", i, "]")),
2475 						    ";");
2476 					}
2477 					else if (flatten_from_ib_var)
2478 						statement(ib_var_ref, ".", mbr_name, " = ", ib_var_ref, ".", flatten_from_ib_mbr_name, "[", i,
2479 						          "];");
2480 					else
2481 						statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), "[", i, "];");
2482 				});
2483 				break;
2484 
2485 			default:
2486 				break;
2487 			}
2488 		}
2489 	}
2490 }
2491 
add_composite_member_variable_to_interface_block(StorageClass storage,const string & ib_var_ref,SPIRType & ib_type,SPIRVariable & var,uint32_t mbr_idx,InterfaceBlockMeta & meta)2492 void CompilerMSL::add_composite_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2493                                                                    SPIRType &ib_type, SPIRVariable &var,
2494                                                                    uint32_t mbr_idx, InterfaceBlockMeta &meta)
2495 {
2496 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2497 	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2498 
2499 	BuiltIn builtin = BuiltInMax;
2500 	bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
2501 	bool is_flat =
2502 	    has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
2503 	bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
2504 	                        has_decoration(var.self, DecorationNoPerspective);
2505 	bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
2506 	                   has_decoration(var.self, DecorationCentroid);
2507 	bool is_sample =
2508 	    has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);
2509 
2510 	uint32_t mbr_type_id = var_type.member_types[mbr_idx];
2511 	auto &mbr_type = get<SPIRType>(mbr_type_id);
2512 	uint32_t elem_cnt = 0;
2513 
2514 	if (is_matrix(mbr_type))
2515 	{
2516 		if (is_array(mbr_type))
2517 			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
2518 
2519 		elem_cnt = mbr_type.columns;
2520 	}
2521 	else if (is_array(mbr_type))
2522 	{
2523 		if (mbr_type.array.size() != 1)
2524 			SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
2525 
2526 		elem_cnt = to_array_size_literal(mbr_type);
2527 	}
2528 
2529 	auto *usable_type = &mbr_type;
2530 	if (usable_type->pointer)
2531 		usable_type = &get<SPIRType>(usable_type->parent_type);
2532 	while (is_array(*usable_type) || is_matrix(*usable_type))
2533 		usable_type = &get<SPIRType>(usable_type->parent_type);
2534 
2535 	bool flatten_from_ib_var = false;
2536 	string flatten_from_ib_mbr_name;
2537 
2538 	if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
2539 	{
2540 		// Also declare [[clip_distance]] attribute here.
2541 		uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
2542 		ib_type.member_types.push_back(mbr_type_id);
2543 		set_member_decoration(ib_type.self, clip_array_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
2544 
2545 		flatten_from_ib_mbr_name = builtin_to_glsl(BuiltInClipDistance, StorageClassOutput);
2546 		set_member_name(ib_type.self, clip_array_mbr_idx, flatten_from_ib_mbr_name);
2547 
2548 		// When we flatten, we flatten directly from the "out" struct,
2549 		// not from a function variable.
2550 		flatten_from_ib_var = true;
2551 
2552 		if (!msl_options.enable_clip_distance_user_varying)
2553 			return;
2554 	}
2555 
2556 	for (uint32_t i = 0; i < elem_cnt; i++)
2557 	{
2558 		// Add a reference to the variable type to the interface struct.
2559 		uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2560 		if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2561 			ib_type.member_types.push_back(build_msl_interpolant_type(usable_type->self, is_noperspective));
2562 		else
2563 			ib_type.member_types.push_back(usable_type->self);
2564 
2565 		// Give the member a name
2566 		string mbr_name = ensure_valid_name(join(to_qualified_member_name(var_type, mbr_idx), "_", i), "m");
2567 		set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2568 
2569 		if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
2570 		{
2571 			uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation) + i;
2572 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2573 			mark_location_as_used_by_shader(locn, *usable_type, storage);
2574 		}
2575 		else if (has_decoration(var.self, DecorationLocation))
2576 		{
2577 			uint32_t locn = get_accumulated_member_location(var, mbr_idx, meta.strip_array) + i;
2578 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2579 			mark_location_as_used_by_shader(locn, *usable_type, storage);
2580 		}
2581 		else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
2582 		{
2583 			uint32_t locn = inputs_by_builtin[builtin].location + i;
2584 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2585 			mark_location_as_used_by_shader(locn, *usable_type, storage);
2586 		}
2587 		else if (is_builtin && (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance))
2588 		{
2589 			// Declare the Clip/CullDistance as [[user(clip/cullN)]].
2590 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2591 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, i);
2592 		}
2593 
2594 		if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
2595 			SPIRV_CROSS_THROW("DecorationComponent on matrices and arrays make little sense.");
2596 
2597 		if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
2598 		{
2599 			// Copy interpolation decorations if needed
2600 			if (is_flat)
2601 				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2602 			if (is_noperspective)
2603 				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2604 			if (is_centroid)
2605 				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2606 			if (is_sample)
2607 				set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2608 		}
2609 
2610 		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2611 		set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);
2612 
2613 		// Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
2614 		if (!meta.strip_array && meta.allow_local_declaration)
2615 		{
2616 			switch (storage)
2617 			{
2618 			case StorageClassInput:
2619 				entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
2620 					if (pull_model_inputs.count(var.self))
2621 					{
2622 						string lerp_call;
2623 						if (is_centroid)
2624 							lerp_call = ".interpolate_at_centroid()";
2625 						else if (is_sample)
2626 							lerp_call = join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
2627 						else
2628 							lerp_call = ".interpolate_at_center()";
2629 						statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), "[", i, "] = ", ib_var_ref,
2630 						          ".", mbr_name, lerp_call, ";");
2631 					}
2632 					else
2633 					{
2634 						statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), "[", i, "] = ", ib_var_ref,
2635 						          ".", mbr_name, ";");
2636 					}
2637 				});
2638 				break;
2639 
2640 			case StorageClassOutput:
2641 				entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
2642 					if (flatten_from_ib_var)
2643 					{
2644 						statement(ib_var_ref, ".", mbr_name, " = ", ib_var_ref, ".", flatten_from_ib_mbr_name, "[", i,
2645 						          "];");
2646 					}
2647 					else
2648 					{
2649 						statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), ".",
2650 						          to_member_name(var_type, mbr_idx), "[", i, "];");
2651 					}
2652 				});
2653 				break;
2654 
2655 			default:
2656 				break;
2657 			}
2658 		}
2659 	}
2660 }
2661 
add_plain_member_variable_to_interface_block(StorageClass storage,const string & ib_var_ref,SPIRType & ib_type,SPIRVariable & var,uint32_t mbr_idx,InterfaceBlockMeta & meta)2662 void CompilerMSL::add_plain_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2663                                                                SPIRType &ib_type, SPIRVariable &var, uint32_t mbr_idx,
2664                                                                InterfaceBlockMeta &meta)
2665 {
2666 	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2667 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2668 
2669 	BuiltIn builtin = BuiltInMax;
2670 	bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
2671 	bool is_flat =
2672 	    has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
2673 	bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
2674 	                        has_decoration(var.self, DecorationNoPerspective);
2675 	bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
2676 	                   has_decoration(var.self, DecorationCentroid);
2677 	bool is_sample =
2678 	    has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);
2679 
2680 	// Add a reference to the member to the interface struct.
2681 	uint32_t mbr_type_id = var_type.member_types[mbr_idx];
2682 	uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2683 	mbr_type_id = ensure_correct_builtin_type(mbr_type_id, builtin);
2684 	var_type.member_types[mbr_idx] = mbr_type_id;
2685 	if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2686 		ib_type.member_types.push_back(build_msl_interpolant_type(mbr_type_id, is_noperspective));
2687 	else
2688 		ib_type.member_types.push_back(mbr_type_id);
2689 
2690 	// Give the member a name
2691 	string mbr_name = ensure_valid_name(to_qualified_member_name(var_type, mbr_idx), "m");
2692 	set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2693 
2694 	// Update the original variable reference to include the structure reference
2695 	string qual_var_name = ib_var_ref + "." + mbr_name;
2696 	// If using pull-model interpolation, need to add a call to the correct interpolation method.
2697 	if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2698 	{
2699 		if (is_centroid)
2700 			qual_var_name += ".interpolate_at_centroid()";
2701 		else if (is_sample)
2702 			qual_var_name += join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
2703 		else
2704 			qual_var_name += ".interpolate_at_center()";
2705 	}
2706 
2707 	bool flatten_stage_out = false;
2708 
2709 	if (is_builtin && !meta.strip_array)
2710 	{
2711 		// For the builtin gl_PerVertex, we cannot treat it as a block anyways,
2712 		// so redirect to qualified name.
2713 		set_member_qualified_name(var_type.self, mbr_idx, qual_var_name);
2714 	}
2715 	else if (!meta.strip_array && meta.allow_local_declaration)
2716 	{
2717 		// Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
2718 		switch (storage)
2719 		{
2720 		case StorageClassInput:
2721 			entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
2722 				statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), " = ", qual_var_name, ";");
2723 			});
2724 			break;
2725 
2726 		case StorageClassOutput:
2727 			flatten_stage_out = true;
2728 			entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
2729 				statement(qual_var_name, " = ", to_name(var.self), ".", to_member_name(var_type, mbr_idx), ";");
2730 			});
2731 			break;
2732 
2733 		default:
2734 			break;
2735 		}
2736 	}
2737 
2738 	// Copy the variable location from the original variable to the member
2739 	if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
2740 	{
2741 		uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation);
2742 		if (storage == StorageClassInput)
2743 		{
2744 			mbr_type_id = ensure_correct_input_type(mbr_type_id, locn, 0, meta.strip_array);
2745 			var_type.member_types[mbr_idx] = mbr_type_id;
2746 			if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2747 				ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective);
2748 			else
2749 				ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2750 		}
2751 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2752 		mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2753 	}
2754 	else if (has_decoration(var.self, DecorationLocation))
2755 	{
2756 		// The block itself might have a location and in this case, all members of the block
2757 		// receive incrementing locations.
2758 		uint32_t locn = get_accumulated_member_location(var, mbr_idx, meta.strip_array);
2759 		if (storage == StorageClassInput)
2760 		{
2761 			mbr_type_id = ensure_correct_input_type(mbr_type_id, locn, 0, meta.strip_array);
2762 			var_type.member_types[mbr_idx] = mbr_type_id;
2763 			if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2764 				ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective);
2765 			else
2766 				ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2767 		}
2768 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2769 		mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2770 	}
2771 	else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
2772 	{
2773 		uint32_t locn = 0;
2774 		auto builtin_itr = inputs_by_builtin.find(builtin);
2775 		if (builtin_itr != end(inputs_by_builtin))
2776 			locn = builtin_itr->second.location;
2777 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2778 		mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2779 	}
2780 
2781 	// Copy the component location, if present.
2782 	if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
2783 	{
2784 		uint32_t comp = get_member_decoration(var_type.self, mbr_idx, DecorationComponent);
2785 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp);
2786 	}
2787 
2788 	// Mark the member as builtin if needed
2789 	if (is_builtin)
2790 	{
2791 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2792 		if (builtin == BuiltInPosition && storage == StorageClassOutput)
2793 			qual_pos_var_name = qual_var_name;
2794 	}
2795 
2796 	const SPIRConstant *c = nullptr;
2797 	if (!flatten_stage_out && var.storage == StorageClassOutput &&
2798 	    var.initializer != ID(0) && (c = maybe_get<SPIRConstant>(var.initializer)))
2799 	{
2800 		if (meta.strip_array)
2801 		{
2802 			entry_func.fixup_hooks_in.push_back([=, &var]() {
2803 				auto &type = this->get<SPIRType>(var.basetype);
2804 				uint32_t index = get_extended_member_decoration(var.self, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex);
2805 
2806 				auto invocation = to_tesc_invocation_id();
2807 				auto constant_chain = join(to_expression(var.initializer), "[", invocation, "]");
2808 				statement(to_expression(stage_out_ptr_var_id), "[",
2809 				          invocation, "].",
2810 				          to_member_name(ib_type, index), " = ",
2811 				          constant_chain, ".", to_member_name(type, mbr_idx), ";");
2812 			});
2813 		}
2814 		else
2815 		{
2816 			entry_func.fixup_hooks_in.push_back([=]() {
2817 				statement(qual_var_name, " = ", constant_expression(
2818 						this->get<SPIRConstant>(c->subconstants[mbr_idx])), ";");
2819 			});
2820 		}
2821 	}
2822 
2823 	if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
2824 	{
2825 		// Copy interpolation decorations if needed
2826 		if (is_flat)
2827 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2828 		if (is_noperspective)
2829 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2830 		if (is_centroid)
2831 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2832 		if (is_sample)
2833 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2834 	}
2835 
2836 	set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2837 	set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);
2838 }
2839 
2840 // In Metal, the tessellation levels are stored as tightly packed half-precision floating point values.
2841 // But, stage-in attribute offsets and strides must be multiples of four, so we can't pass the levels
2842 // individually. Therefore, we must pass them as vectors. Triangles get a single float4, with the outer
2843 // levels in 'xyz' and the inner level in 'w'. Quads get a float4 containing the outer levels and a
2844 // float2 containing the inner levels.
add_tess_level_input_to_interface_block(const std::string & ib_var_ref,SPIRType & ib_type,SPIRVariable & var)2845 void CompilerMSL::add_tess_level_input_to_interface_block(const std::string &ib_var_ref, SPIRType &ib_type,
2846                                                           SPIRVariable &var)
2847 {
2848 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2849 	auto &var_type = get_variable_element_type(var);
2850 
2851 	BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
2852 
2853 	// Force the variable to have the proper name.
2854 	string var_name = builtin_to_glsl(builtin, StorageClassFunction);
2855 	set_name(var.self, var_name);
2856 
2857 	// We need to declare the variable early and at entry-point scope.
2858 	entry_func.add_local_variable(var.self);
2859 	vars_needing_early_declaration.push_back(var.self);
2860 	bool triangles = get_execution_mode_bitset().get(ExecutionModeTriangles);
2861 	string mbr_name;
2862 
2863 	// Add a reference to the variable type to the interface struct.
2864 	uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2865 
2866 	const auto mark_locations = [&](const SPIRType &new_var_type) {
2867 		if (get_decoration_bitset(var.self).get(DecorationLocation))
2868 		{
2869 			uint32_t locn = get_decoration(var.self, DecorationLocation);
2870 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2871 			mark_location_as_used_by_shader(locn, new_var_type, StorageClassInput);
2872 		}
2873 		else if (inputs_by_builtin.count(builtin))
2874 		{
2875 			uint32_t locn = inputs_by_builtin[builtin].location;
2876 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2877 			mark_location_as_used_by_shader(locn, new_var_type, StorageClassInput);
2878 		}
2879 	};
2880 
2881 	if (triangles)
2882 	{
2883 		// Triangles are tricky, because we want only one member in the struct.
2884 		mbr_name = "gl_TessLevel";
2885 
2886 		// If we already added the other one, we can skip this step.
2887 		if (!added_builtin_tess_level)
2888 		{
2889 			uint32_t type_id = build_extended_vector_type(var_type.self, 4);
2890 
2891 			ib_type.member_types.push_back(type_id);
2892 
2893 			// Give the member a name
2894 			set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2895 
2896 			// We cannot decorate both, but the important part is that
2897 			// it's marked as builtin so we can get automatic attribute assignment if needed.
2898 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2899 
2900 			mark_locations(var_type);
2901 			added_builtin_tess_level = true;
2902 		}
2903 	}
2904 	else
2905 	{
2906 		mbr_name = var_name;
2907 
2908 		uint32_t type_id = build_extended_vector_type(var_type.self, builtin == BuiltInTessLevelOuter ? 4 : 2);
2909 
2910 		uint32_t ptr_type_id = ir.increase_bound_by(1);
2911 		auto &new_var_type = set<SPIRType>(ptr_type_id, get<SPIRType>(type_id));
2912 		new_var_type.pointer = true;
2913 		new_var_type.pointer_depth++;
2914 		new_var_type.storage = StorageClassInput;
2915 		new_var_type.parent_type = type_id;
2916 
2917 		ib_type.member_types.push_back(type_id);
2918 
2919 		// Give the member a name
2920 		set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2921 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2922 
2923 		mark_locations(new_var_type);
2924 	}
2925 
2926 	if (builtin == BuiltInTessLevelOuter)
2927 	{
2928 		entry_func.fixup_hooks_in.push_back([=]() {
2929 			statement(var_name, "[0] = ", ib_var_ref, ".", mbr_name, ".x;");
2930 			statement(var_name, "[1] = ", ib_var_ref, ".", mbr_name, ".y;");
2931 			statement(var_name, "[2] = ", ib_var_ref, ".", mbr_name, ".z;");
2932 			if (!triangles)
2933 				statement(var_name, "[3] = ", ib_var_ref, ".", mbr_name, ".w;");
2934 		});
2935 	}
2936 	else
2937 	{
2938 		entry_func.fixup_hooks_in.push_back([=]() {
2939 			if (triangles)
2940 			{
2941 				statement(var_name, "[0] = ", ib_var_ref, ".", mbr_name, ".w;");
2942 			}
2943 			else
2944 			{
2945 				statement(var_name, "[0] = ", ib_var_ref, ".", mbr_name, ".x;");
2946 				statement(var_name, "[1] = ", ib_var_ref, ".", mbr_name, ".y;");
2947 			}
2948 		});
2949 	}
2950 }
2951 
variable_storage_requires_stage_io(spv::StorageClass storage) const2952 bool CompilerMSL::variable_storage_requires_stage_io(spv::StorageClass storage) const
2953 {
2954 	if (storage == StorageClassOutput)
2955 		return !capture_output_to_buffer;
2956 	else if (storage == StorageClassInput)
2957 		return !(get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup);
2958 	else
2959 		return false;
2960 }
2961 
to_tesc_invocation_id()2962 string CompilerMSL::to_tesc_invocation_id()
2963 {
2964 	if (msl_options.multi_patch_workgroup)
2965 	{
2966 		// n.b. builtin_invocation_id_id here is the dispatch global invocation ID,
2967 		// not the TC invocation ID.
2968 		return join(to_expression(builtin_invocation_id_id), ".x % ", get_entry_point().output_vertices);
2969 	}
2970 	else
2971 		return builtin_to_glsl(BuiltInInvocationId, StorageClassInput);
2972 }
2973 
emit_local_masked_variable(const SPIRVariable & masked_var,bool strip_array)2974 void CompilerMSL::emit_local_masked_variable(const SPIRVariable &masked_var, bool strip_array)
2975 {
2976 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2977 	bool threadgroup_storage = variable_decl_is_remapped_storage(masked_var, StorageClassWorkgroup);
2978 
2979 	if (threadgroup_storage && msl_options.multi_patch_workgroup)
2980 	{
2981 		// We need one threadgroup block per patch, so fake this.
2982 		entry_func.fixup_hooks_in.push_back([this, &masked_var]() {
2983 			auto &type = get_variable_data_type(masked_var);
2984 			add_local_variable_name(masked_var.self);
2985 
2986 			bool old_is_builtin = is_using_builtin_array;
2987 			is_using_builtin_array = true;
2988 
2989 			const uint32_t max_control_points_per_patch = 32u;
2990 			uint32_t max_num_instances =
2991 					(max_control_points_per_patch + get_entry_point().output_vertices - 1u) /
2992 					get_entry_point().output_vertices;
2993 			statement("threadgroup ", type_to_glsl(type), " ",
2994 			          "spvStorage", to_name(masked_var.self), "[", max_num_instances, "]",
2995 			          type_to_array_glsl(type), ";");
2996 
2997 			// Assign a threadgroup slice to each PrimitiveID.
2998 			// We assume here that workgroup size is rounded to 32,
2999 			// since that's the maximum number of control points per patch.
3000 			// We cannot size the array based on fixed dispatch parameters,
3001 			// since Metal does not allow that. :(
3002 			// FIXME: We will likely need an option to support passing down target workgroup size,
3003 			// so we can emit appropriate size here.
3004 			statement("threadgroup ", type_to_glsl(type), " ",
3005 			          "(&", to_name(masked_var.self), ")",
3006 			          type_to_array_glsl(type), " = spvStorage", to_name(masked_var.self), "[",
3007 			          "(", to_expression(builtin_invocation_id_id), ".x / ",
3008 			          get_entry_point().output_vertices, ") % ",
3009 			          max_num_instances, "];");
3010 
3011 			is_using_builtin_array = old_is_builtin;
3012 		});
3013 	}
3014 	else
3015 	{
3016 		entry_func.add_local_variable(masked_var.self);
3017 	}
3018 
3019 	if (!threadgroup_storage)
3020 	{
3021 		vars_needing_early_declaration.push_back(masked_var.self);
3022 	}
3023 	else if (masked_var.initializer)
3024 	{
3025 		// Cannot directly initialize threadgroup variables. Need fixup hooks.
3026 		ID initializer = masked_var.initializer;
3027 		if (strip_array)
3028 		{
3029 			entry_func.fixup_hooks_in.push_back([this, &masked_var, initializer]() {
3030 				auto invocation = to_tesc_invocation_id();
3031 				statement(to_expression(masked_var.self), "[",
3032 				          invocation, "] = ",
3033 				          to_expression(initializer), "[",
3034 				          invocation, "];");
3035 			});
3036 		}
3037 		else
3038 		{
3039 			entry_func.fixup_hooks_in.push_back([this, &masked_var, initializer]() {
3040 				statement(to_expression(masked_var.self), " = ", to_expression(initializer), ";");
3041 			});
3042 		}
3043 	}
3044 }
3045 
add_variable_to_interface_block(StorageClass storage,const string & ib_var_ref,SPIRType & ib_type,SPIRVariable & var,InterfaceBlockMeta & meta)3046 void CompilerMSL::add_variable_to_interface_block(StorageClass storage, const string &ib_var_ref, SPIRType &ib_type,
3047                                                   SPIRVariable &var, InterfaceBlockMeta &meta)
3048 {
3049 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
3050 	// Tessellation control I/O variables and tessellation evaluation per-point inputs are
3051 	// usually declared as arrays. In these cases, we want to add the element type to the
3052 	// interface block, since in Metal it's the interface block itself which is arrayed.
3053 	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
3054 	bool is_builtin = is_builtin_variable(var);
3055 	auto builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
3056 	bool is_block = has_decoration(var_type.self, DecorationBlock);
3057 
3058 	// If stage variables are masked out, emit them as plain variables instead.
3059 	// For builtins, we query them one by one later.
3060 	// IO blocks are not masked here, we need to mask them per-member instead.
3061 	if (storage == StorageClassOutput && is_stage_output_variable_masked(var))
3062 	{
3063 		// If we ignore an output, we must still emit it, since it might be used by app.
3064 		// Instead, just emit it as early declaration.
3065 		emit_local_masked_variable(var, meta.strip_array);
3066 		return;
3067 	}
3068 
3069 	if (var_type.basetype == SPIRType::Struct)
3070 	{
3071 		bool block_requires_flattening = variable_storage_requires_stage_io(storage) || is_block;
3072 		bool needs_local_declaration = !is_builtin && block_requires_flattening && meta.allow_local_declaration;
3073 
3074 		if (needs_local_declaration)
3075 		{
3076 			// For I/O blocks or structs, we will need to pass the block itself around
3077 			// to functions if they are used globally in leaf functions.
3078 			// Rather than passing down member by member,
3079 			// we unflatten I/O blocks while running the shader,
3080 			// and pass the actual struct type down to leaf functions.
3081 			// We then unflatten inputs, and flatten outputs in the "fixup" stages.
3082 			emit_local_masked_variable(var, meta.strip_array);
3083 		}
3084 
3085 		if (!block_requires_flattening)
3086 		{
3087 			// In Metal tessellation shaders, the interface block itself is arrayed. This makes things
3088 			// very complicated, since stage-in structures in MSL don't support nested structures.
3089 			// Luckily, for stage-out when capturing output, we can avoid this and just add
3090 			// composite members directly, because the stage-out structure is stored to a buffer,
3091 			// not returned.
3092 			add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
3093 		}
3094 		else
3095 		{
3096 			bool masked_block = false;
3097 
3098 			// Flatten the struct members into the interface struct
3099 			for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++)
3100 			{
3101 				builtin = BuiltInMax;
3102 				is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
3103 				auto &mbr_type = get<SPIRType>(var_type.member_types[mbr_idx]);
3104 
3105 				if (storage == StorageClassOutput && is_stage_output_block_member_masked(var, mbr_idx, meta.strip_array))
3106 				{
3107 					if (is_block)
3108 						masked_block = true;
3109 
3110 					// Non-builtin block output variables are just ignored, since they will still access
3111 					// the block variable as-is. They're just not flattened.
3112 					if (is_builtin && !meta.strip_array)
3113 					{
3114 						// Emit a fake variable instead.
3115 						uint32_t ids = ir.increase_bound_by(2);
3116 						uint32_t ptr_type_id = ids + 0;
3117 						uint32_t var_id = ids + 1;
3118 
3119 						auto ptr_type = mbr_type;
3120 						ptr_type.pointer = true;
3121 						ptr_type.pointer_depth++;
3122 						ptr_type.parent_type = var_type.member_types[mbr_idx];
3123 						ptr_type.storage = StorageClassOutput;
3124 
3125 						uint32_t initializer = 0;
3126 						if (var.initializer)
3127 							if (auto *c = maybe_get<SPIRConstant>(var.initializer))
3128 								initializer = c->subconstants[mbr_idx];
3129 
3130 						set<SPIRType>(ptr_type_id, ptr_type);
3131 						set<SPIRVariable>(var_id, ptr_type_id, StorageClassOutput, initializer);
3132 						entry_func.add_local_variable(var_id);
3133 						vars_needing_early_declaration.push_back(var_id);
3134 						set_name(var_id, builtin_to_glsl(builtin, StorageClassOutput));
3135 						set_decoration(var_id, DecorationBuiltIn, builtin);
3136 					}
3137 				}
3138 				else if (!is_builtin || has_active_builtin(builtin, storage))
3139 				{
3140 					bool is_composite_type = is_matrix(mbr_type) || is_array(mbr_type);
3141 					bool attribute_load_store =
3142 					    storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;
3143 					bool storage_is_stage_io = variable_storage_requires_stage_io(storage);
3144 
3145 					// Clip/CullDistance always need to be declared as user attributes.
3146 					if (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance)
3147 						is_builtin = false;
3148 
3149 					if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
3150 					{
3151 						add_composite_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx,
3152 						                                                 meta);
3153 					}
3154 					else
3155 					{
3156 						add_plain_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx, meta);
3157 					}
3158 				}
3159 			}
3160 
3161 			// If we're redirecting a block, we might still need to access the original block
3162 			// variable if we're masking some members.
3163 			if (masked_block && !needs_local_declaration &&
3164 			    (!is_builtin_variable(var) || get_execution_model() == ExecutionModelTessellationControl))
3165 			{
3166 				if (is_builtin_variable(var))
3167 				{
3168 					// Ensure correct names for the block members if we're actually going to
3169 					// declare gl_PerVertex.
3170 					for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++)
3171 					{
3172 						set_member_name(var_type.self, mbr_idx, builtin_to_glsl(
3173 								BuiltIn(get_member_decoration(var_type.self, mbr_idx, DecorationBuiltIn)),
3174 								StorageClassOutput));
3175 					}
3176 
3177 					set_name(var_type.self, "gl_PerVertex");
3178 					set_name(var.self, "gl_out_masked");
3179 					stage_out_masked_builtin_type_id = var_type.self;
3180 				}
3181 				emit_local_masked_variable(var, meta.strip_array);
3182 			}
3183 		}
3184 	}
3185 	else if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput &&
3186 	         !meta.strip_array && is_builtin && (builtin == BuiltInTessLevelOuter || builtin == BuiltInTessLevelInner))
3187 	{
3188 		add_tess_level_input_to_interface_block(ib_var_ref, ib_type, var);
3189 	}
3190 	else if (var_type.basetype == SPIRType::Boolean || var_type.basetype == SPIRType::Char ||
3191 	         type_is_integral(var_type) || type_is_floating_point(var_type))
3192 	{
3193 		if (!is_builtin || has_active_builtin(builtin, storage))
3194 		{
3195 			bool is_composite_type = is_matrix(var_type) || is_array(var_type);
3196 			bool storage_is_stage_io = variable_storage_requires_stage_io(storage);
3197 			bool attribute_load_store = storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;
3198 
3199 			// Clip/CullDistance always needs to be declared as user attributes.
3200 			if (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance)
3201 				is_builtin = false;
3202 
3203 			// MSL does not allow matrices or arrays in input or output variables, so need to handle it specially.
3204 			if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
3205 			{
3206 				add_composite_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
3207 			}
3208 			else
3209 			{
3210 				add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
3211 			}
3212 		}
3213 	}
3214 }
3215 
3216 // Fix up the mapping of variables to interface member indices, which is used to compile access chains
3217 // for per-vertex variables in a tessellation control shader.
fix_up_interface_member_indices(StorageClass storage,uint32_t ib_type_id)3218 void CompilerMSL::fix_up_interface_member_indices(StorageClass storage, uint32_t ib_type_id)
3219 {
3220 	// Only needed for tessellation shaders and pull-model interpolants.
3221 	// Need to redirect interface indices back to variables themselves.
3222 	// For structs, each member of the struct need a separate instance.
3223 	if (get_execution_model() != ExecutionModelTessellationControl &&
3224 	    !(get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput) &&
3225 	    !(get_execution_model() == ExecutionModelFragment && storage == StorageClassInput &&
3226 	      !pull_model_inputs.empty()))
3227 		return;
3228 
3229 	auto mbr_cnt = uint32_t(ir.meta[ib_type_id].members.size());
3230 	for (uint32_t i = 0; i < mbr_cnt; i++)
3231 	{
3232 		uint32_t var_id = get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceOrigID);
3233 		if (!var_id)
3234 			continue;
3235 		auto &var = get<SPIRVariable>(var_id);
3236 
3237 		auto &type = get_variable_element_type(var);
3238 
3239 		bool flatten_composites = variable_storage_requires_stage_io(var.storage);
3240 		bool is_block = has_decoration(type.self, DecorationBlock);
3241 
3242 		uint32_t mbr_idx = uint32_t(-1);
3243 		if (type.basetype == SPIRType::Struct && (flatten_composites || is_block))
3244 			mbr_idx = get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceMemberIndex);
3245 
3246 		if (mbr_idx != uint32_t(-1))
3247 		{
3248 			// Only set the lowest InterfaceMemberIndex for each variable member.
3249 			// IB struct members will be emitted in-order w.r.t. interface member index.
3250 			if (!has_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex))
3251 				set_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, i);
3252 		}
3253 		else
3254 		{
3255 			// Only set the lowest InterfaceMemberIndex for each variable.
3256 			// IB struct members will be emitted in-order w.r.t. interface member index.
3257 			if (!has_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex))
3258 				set_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex, i);
3259 		}
3260 	}
3261 }
3262 
3263 // Add an interface structure for the type of storage, which is either StorageClassInput or StorageClassOutput.
3264 // Returns the ID of the newly added variable, or zero if no variable was added.
add_interface_block(StorageClass storage,bool patch)3265 uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch)
3266 {
3267 	// Accumulate the variables that should appear in the interface struct.
3268 	SmallVector<SPIRVariable *> vars;
3269 	bool incl_builtins = storage == StorageClassOutput || is_tessellation_shader();
3270 	bool has_seen_barycentric = false;
3271 
3272 	InterfaceBlockMeta meta;
3273 
3274 	// Varying interfaces between stages which use "user()" attribute can be dealt with
3275 	// without explicit packing and unpacking of components. For any variables which link against the runtime
3276 	// in some way (vertex attributes, fragment output, etc), we'll need to deal with it somehow.
3277 	bool pack_components =
3278 	    (storage == StorageClassInput && get_execution_model() == ExecutionModelVertex) ||
3279 	    (storage == StorageClassOutput && get_execution_model() == ExecutionModelFragment) ||
3280 	    (storage == StorageClassOutput && get_execution_model() == ExecutionModelVertex && capture_output_to_buffer);
3281 
3282 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
3283 		if (var.storage != storage)
3284 			return;
3285 
3286 		auto &type = this->get<SPIRType>(var.basetype);
3287 
3288 		bool is_builtin = is_builtin_variable(var);
3289 		bool is_block = has_decoration(type.self, DecorationBlock);
3290 
3291 		auto bi_type = BuiltInMax;
3292 		bool builtin_is_gl_in_out = false;
3293 		if (is_builtin && !is_block)
3294 		{
3295 			bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
3296 			builtin_is_gl_in_out = bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
3297 			                       bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance;
3298 		}
3299 
3300 		if (is_builtin && is_block)
3301 			builtin_is_gl_in_out = true;
3302 
3303 		uint32_t location = get_decoration(var_id, DecorationLocation);
3304 
3305 		bool builtin_is_stage_in_out = builtin_is_gl_in_out ||
3306 		                               bi_type == BuiltInLayer || bi_type == BuiltInViewportIndex ||
3307 		                               bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV ||
3308 		                               bi_type == BuiltInFragDepth ||
3309 		                               bi_type == BuiltInFragStencilRefEXT || bi_type == BuiltInSampleMask;
3310 
3311 		// These builtins are part of the stage in/out structs.
3312 		bool is_interface_block_builtin =
3313 				builtin_is_stage_in_out ||
3314 				(get_execution_model() == ExecutionModelTessellationEvaluation &&
3315 				 (bi_type == BuiltInTessLevelOuter || bi_type == BuiltInTessLevelInner));
3316 
3317 		bool is_active = interface_variable_exists_in_entry_point(var.self);
3318 		if (is_builtin && is_active)
3319 		{
3320 			// Only emit the builtin if it's active in this entry point. Interface variable list might lie.
3321 			if (is_block)
3322 			{
3323 				// If any builtin is active, the block is active.
3324 				uint32_t mbr_cnt = uint32_t(type.member_types.size());
3325 				for (uint32_t i = 0; !is_active && i < mbr_cnt; i++)
3326 					is_active = has_active_builtin(BuiltIn(get_member_decoration(type.self, i, DecorationBuiltIn)), storage);
3327 			}
3328 			else
3329 			{
3330 				is_active = has_active_builtin(bi_type, storage);
3331 			}
3332 		}
3333 
3334 		bool filter_patch_decoration = (has_decoration(var_id, DecorationPatch) || is_patch_block(type)) == patch;
3335 
3336 		bool hidden = is_hidden_variable(var, incl_builtins);
3337 
3338 		// ClipDistance is never hidden, we need to emulate it when used as an input.
3339 		if (bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance)
3340 			hidden = false;
3341 
3342 		// It's not enough to simply avoid marking fragment outputs if the pipeline won't
3343 		// accept them. We can't put them in the struct at all, or otherwise the compiler
3344 		// complains that the outputs weren't explicitly marked.
3345 		if (get_execution_model() == ExecutionModelFragment && storage == StorageClassOutput && !patch &&
3346 		    ((is_builtin && ((bi_type == BuiltInFragDepth && !msl_options.enable_frag_depth_builtin) ||
3347 		                     (bi_type == BuiltInFragStencilRefEXT && !msl_options.enable_frag_stencil_ref_builtin))) ||
3348 		     (!is_builtin && !(msl_options.enable_frag_output_mask & (1 << location)))))
3349 		{
3350 			hidden = true;
3351 			disabled_frag_outputs.push_back(var_id);
3352 			// If a builtin, force it to have the proper name.
3353 			if (is_builtin)
3354 				set_name(var_id, builtin_to_glsl(bi_type, StorageClassFunction));
3355 		}
3356 
3357 		// Barycentric inputs must be emitted in stage-in, because they can have interpolation arguments.
3358 		if (is_active && (bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV))
3359 		{
3360 			if (has_seen_barycentric)
3361 				SPIRV_CROSS_THROW("Cannot declare both BaryCoordNV and BaryCoordNoPerspNV in same shader in MSL.");
3362 			has_seen_barycentric = true;
3363 			hidden = false;
3364 		}
3365 
3366 		if (is_active && !hidden && type.pointer && filter_patch_decoration &&
3367 		    (!is_builtin || is_interface_block_builtin))
3368 		{
3369 			vars.push_back(&var);
3370 
3371 			if (!is_builtin)
3372 			{
3373 				// Need to deal specially with DecorationComponent.
3374 				// Multiple variables can alias the same Location, and try to make sure each location is declared only once.
3375 				// We will swizzle data in and out to make this work.
3376 				// This is only relevant for vertex inputs and fragment outputs.
3377 				// Technically tessellation as well, but it is too complicated to support.
3378 				uint32_t component = get_decoration(var_id, DecorationComponent);
3379 				if (component != 0)
3380 				{
3381 					if (is_tessellation_shader())
3382 						SPIRV_CROSS_THROW("Component decoration is not supported in tessellation shaders.");
3383 					else if (pack_components)
3384 					{
3385 						uint32_t array_size = 1;
3386 						if (!type.array.empty())
3387 							array_size = to_array_size_literal(type);
3388 
3389 						for (uint32_t location_offset = 0; location_offset < array_size; location_offset++)
3390 						{
3391 							auto &location_meta = meta.location_meta[location + location_offset];
3392 							location_meta.num_components = std::max(location_meta.num_components, component + type.vecsize);
3393 
3394 							// For variables sharing location, decorations and base type must match.
3395 							location_meta.base_type_id = type.self;
3396 							location_meta.flat = has_decoration(var.self, DecorationFlat);
3397 							location_meta.noperspective = has_decoration(var.self, DecorationNoPerspective);
3398 							location_meta.centroid = has_decoration(var.self, DecorationCentroid);
3399 							location_meta.sample = has_decoration(var.self, DecorationSample);
3400 						}
3401 					}
3402 				}
3403 			}
3404 		}
3405 	});
3406 
3407 	// If no variables qualify, leave.
3408 	// For patch input in a tessellation evaluation shader, the per-vertex stage inputs
3409 	// are included in a special patch control point array.
3410 	if (vars.empty() && !(storage == StorageClassInput && patch && stage_in_var_id))
3411 		return 0;
3412 
3413 	// Add a new typed variable for this interface structure.
3414 	// The initializer expression is allocated here, but populated when the function
3415 	// declaraion is emitted, because it is cleared after each compilation pass.
3416 	uint32_t next_id = ir.increase_bound_by(3);
3417 	uint32_t ib_type_id = next_id++;
3418 	auto &ib_type = set<SPIRType>(ib_type_id);
3419 	ib_type.basetype = SPIRType::Struct;
3420 	ib_type.storage = storage;
3421 	set_decoration(ib_type_id, DecorationBlock);
3422 
3423 	uint32_t ib_var_id = next_id++;
3424 	auto &var = set<SPIRVariable>(ib_var_id, ib_type_id, storage, 0);
3425 	var.initializer = next_id++;
3426 
3427 	string ib_var_ref;
3428 	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
3429 	switch (storage)
3430 	{
3431 	case StorageClassInput:
3432 		ib_var_ref = patch ? patch_stage_in_var_name : stage_in_var_name;
3433 		if (get_execution_model() == ExecutionModelTessellationControl)
3434 		{
3435 			// Add a hook to populate the shared workgroup memory containing the gl_in array.
3436 			entry_func.fixup_hooks_in.push_back([=]() {
3437 				// Can't use PatchVertices, PrimitiveId, or InvocationId yet; the hooks for those may not have run yet.
3438 				if (msl_options.multi_patch_workgroup)
3439 				{
3440 					// n.b. builtin_invocation_id_id here is the dispatch global invocation ID,
3441 					// not the TC invocation ID.
3442 					statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_in = &",
3443 					          input_buffer_var_name, "[min(", to_expression(builtin_invocation_id_id), ".x / ",
3444 					          get_entry_point().output_vertices,
3445 					          ", spvIndirectParams[1] - 1) * spvIndirectParams[0]];");
3446 				}
3447 				else
3448 				{
3449 					// It's safe to use InvocationId here because it's directly mapped to a
3450 					// Metal builtin, and therefore doesn't need a hook.
3451 					statement("if (", to_expression(builtin_invocation_id_id), " < spvIndirectParams[0])");
3452 					statement("    ", input_wg_var_name, "[", to_expression(builtin_invocation_id_id),
3453 					          "] = ", ib_var_ref, ";");
3454 					statement("threadgroup_barrier(mem_flags::mem_threadgroup);");
3455 					statement("if (", to_expression(builtin_invocation_id_id),
3456 					          " >= ", get_entry_point().output_vertices, ")");
3457 					statement("    return;");
3458 				}
3459 			});
3460 		}
3461 		break;
3462 
3463 	case StorageClassOutput:
3464 	{
3465 		ib_var_ref = patch ? patch_stage_out_var_name : stage_out_var_name;
3466 
3467 		// Add the output interface struct as a local variable to the entry function.
3468 		// If the entry point should return the output struct, set the entry function
3469 		// to return the output interface struct, otherwise to return nothing.
3470 		// Indicate the output var requires early initialization.
3471 		bool ep_should_return_output = !get_is_rasterization_disabled();
3472 		uint32_t rtn_id = ep_should_return_output ? ib_var_id : 0;
3473 		if (!capture_output_to_buffer)
3474 		{
3475 			entry_func.add_local_variable(ib_var_id);
3476 			for (auto &blk_id : entry_func.blocks)
3477 			{
3478 				auto &blk = get<SPIRBlock>(blk_id);
3479 				if (blk.terminator == SPIRBlock::Return)
3480 					blk.return_value = rtn_id;
3481 			}
3482 			vars_needing_early_declaration.push_back(ib_var_id);
3483 		}
3484 		else
3485 		{
3486 			switch (get_execution_model())
3487 			{
3488 			case ExecutionModelVertex:
3489 			case ExecutionModelTessellationEvaluation:
3490 				// Instead of declaring a struct variable to hold the output and then
3491 				// copying that to the output buffer, we'll declare the output variable
3492 				// as a reference to the final output element in the buffer. Then we can
3493 				// avoid the extra copy.
3494 				entry_func.fixup_hooks_in.push_back([=]() {
3495 					if (stage_out_var_id)
3496 					{
3497 						// The first member of the indirect buffer is always the number of vertices
3498 						// to draw.
3499 						// We zero-base the InstanceID & VertexID variables for HLSL emulation elsewhere, so don't do it twice
3500 						if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
3501 						{
3502 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3503 							          " = ", output_buffer_var_name, "[", to_expression(builtin_invocation_id_id),
3504 							          ".y * ", to_expression(builtin_stage_input_size_id), ".x + ",
3505 							          to_expression(builtin_invocation_id_id), ".x];");
3506 						}
3507 						else if (msl_options.enable_base_index_zero)
3508 						{
3509 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3510 							          " = ", output_buffer_var_name, "[", to_expression(builtin_instance_idx_id),
3511 							          " * spvIndirectParams[0] + ", to_expression(builtin_vertex_idx_id), "];");
3512 						}
3513 						else
3514 						{
3515 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3516 							          " = ", output_buffer_var_name, "[(", to_expression(builtin_instance_idx_id),
3517 							          " - ", to_expression(builtin_base_instance_id), ") * spvIndirectParams[0] + ",
3518 							          to_expression(builtin_vertex_idx_id), " - ",
3519 							          to_expression(builtin_base_vertex_id), "];");
3520 						}
3521 					}
3522 				});
3523 				break;
3524 			case ExecutionModelTessellationControl:
3525 				if (msl_options.multi_patch_workgroup)
3526 				{
3527 					// We cannot use PrimitiveId here, because the hook may not have run yet.
3528 					if (patch)
3529 					{
3530 						entry_func.fixup_hooks_in.push_back([=]() {
3531 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3532 							          " = ", patch_output_buffer_var_name, "[", to_expression(builtin_invocation_id_id),
3533 							          ".x / ", get_entry_point().output_vertices, "];");
3534 						});
3535 					}
3536 					else
3537 					{
3538 						entry_func.fixup_hooks_in.push_back([=]() {
3539 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &",
3540 							          output_buffer_var_name, "[", to_expression(builtin_invocation_id_id), ".x - ",
3541 							          to_expression(builtin_invocation_id_id), ".x % ",
3542 							          get_entry_point().output_vertices, "];");
3543 						});
3544 					}
3545 				}
3546 				else
3547 				{
3548 					if (patch)
3549 					{
3550 						entry_func.fixup_hooks_in.push_back([=]() {
3551 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3552 							          " = ", patch_output_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
3553 							          "];");
3554 						});
3555 					}
3556 					else
3557 					{
3558 						entry_func.fixup_hooks_in.push_back([=]() {
3559 							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &",
3560 							          output_buffer_var_name, "[", to_expression(builtin_primitive_id_id), " * ",
3561 							          get_entry_point().output_vertices, "];");
3562 						});
3563 					}
3564 				}
3565 				break;
3566 			default:
3567 				break;
3568 			}
3569 		}
3570 		break;
3571 	}
3572 
3573 	default:
3574 		break;
3575 	}
3576 
3577 	set_name(ib_type_id, to_name(ir.default_entry_point) + "_" + ib_var_ref);
3578 	set_name(ib_var_id, ib_var_ref);
3579 
3580 	for (auto *p_var : vars)
3581 	{
3582 		bool strip_array =
3583 		    (get_execution_model() == ExecutionModelTessellationControl ||
3584 		     (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput)) &&
3585 		    !patch;
3586 
3587 		// Fixing up flattened stores in TESC is impossible since the memory is group shared either via
3588 		// device (not masked) or threadgroup (masked) storage classes and it's race condition city.
3589 		meta.strip_array = strip_array;
3590 		meta.allow_local_declaration = !strip_array && !(get_execution_model() == ExecutionModelTessellationControl &&
3591 		                                                 storage == StorageClassOutput);
3592 		add_variable_to_interface_block(storage, ib_var_ref, ib_type, *p_var, meta);
3593 	}
3594 
3595 	if (get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup &&
3596 	    storage == StorageClassInput)
3597 	{
3598 		// For tessellation control inputs, add all outputs from the vertex shader to ensure
3599 		// the struct containing them is the correct size and layout.
3600 		for (auto &input : inputs_by_location)
3601 		{
3602 			if (location_inputs_in_use.count(input.first) != 0)
3603 				continue;
3604 
3605 			// Create a fake variable to put at the location.
3606 			uint32_t offset = ir.increase_bound_by(4);
3607 			uint32_t type_id = offset;
3608 			uint32_t array_type_id = offset + 1;
3609 			uint32_t ptr_type_id = offset + 2;
3610 			uint32_t var_id = offset + 3;
3611 
3612 			SPIRType type;
3613 			switch (input.second.format)
3614 			{
3615 			case MSL_SHADER_INPUT_FORMAT_UINT16:
3616 			case MSL_SHADER_INPUT_FORMAT_ANY16:
3617 				type.basetype = SPIRType::UShort;
3618 				type.width = 16;
3619 				break;
3620 			case MSL_SHADER_INPUT_FORMAT_ANY32:
3621 			default:
3622 				type.basetype = SPIRType::UInt;
3623 				type.width = 32;
3624 				break;
3625 			}
3626 			type.vecsize = input.second.vecsize;
3627 			set<SPIRType>(type_id, type);
3628 
3629 			type.array.push_back(0);
3630 			type.array_size_literal.push_back(true);
3631 			type.parent_type = type_id;
3632 			set<SPIRType>(array_type_id, type);
3633 
3634 			type.pointer = true;
3635 			type.pointer_depth++;
3636 			type.parent_type = array_type_id;
3637 			type.storage = storage;
3638 			auto &ptr_type = set<SPIRType>(ptr_type_id, type);
3639 			ptr_type.self = array_type_id;
3640 
3641 			auto &fake_var = set<SPIRVariable>(var_id, ptr_type_id, storage);
3642 			set_decoration(var_id, DecorationLocation, input.first);
3643 			meta.strip_array = true;
3644 			meta.allow_local_declaration = false;
3645 			add_variable_to_interface_block(storage, ib_var_ref, ib_type, fake_var, meta);
3646 		}
3647 	}
3648 
3649 	// When multiple variables need to access same location,
3650 	// unroll locations one by one and we will flatten output or input as necessary.
3651 	for (auto &loc : meta.location_meta)
3652 	{
3653 		uint32_t location = loc.first;
3654 		auto &location_meta = loc.second;
3655 
3656 		uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
3657 		uint32_t type_id = build_extended_vector_type(location_meta.base_type_id, location_meta.num_components);
3658 		ib_type.member_types.push_back(type_id);
3659 
3660 		set_member_name(ib_type.self, ib_mbr_idx, join("m_location_", location));
3661 		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, location);
3662 		mark_location_as_used_by_shader(location, get<SPIRType>(type_id), storage);
3663 
3664 		if (location_meta.flat)
3665 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
3666 		if (location_meta.noperspective)
3667 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
3668 		if (location_meta.centroid)
3669 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
3670 		if (location_meta.sample)
3671 			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
3672 	}
3673 
3674 	// Sort the members of the structure by their locations.
3675 	MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::LocationThenBuiltInType);
3676 	member_sorter.sort();
3677 
3678 	// The member indices were saved to the original variables, but after the members
3679 	// were sorted, those indices are now likely incorrect. Fix those up now.
3680 	fix_up_interface_member_indices(storage, ib_type_id);
3681 
3682 	// For patch inputs, add one more member, holding the array of control point data.
3683 	if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput && patch &&
3684 	    stage_in_var_id)
3685 	{
3686 		uint32_t pcp_type_id = ir.increase_bound_by(1);
3687 		auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
3688 		pcp_type.basetype = SPIRType::ControlPointArray;
3689 		pcp_type.parent_type = pcp_type.type_alias = get_stage_in_struct_type().self;
3690 		pcp_type.storage = storage;
3691 		ir.meta[pcp_type_id] = ir.meta[ib_type.self];
3692 		uint32_t mbr_idx = uint32_t(ib_type.member_types.size());
3693 		ib_type.member_types.push_back(pcp_type_id);
3694 		set_member_name(ib_type.self, mbr_idx, "gl_in");
3695 	}
3696 
3697 	return ib_var_id;
3698 }
3699 
add_interface_block_pointer(uint32_t ib_var_id,StorageClass storage)3700 uint32_t CompilerMSL::add_interface_block_pointer(uint32_t ib_var_id, StorageClass storage)
3701 {
3702 	if (!ib_var_id)
3703 		return 0;
3704 
3705 	uint32_t ib_ptr_var_id;
3706 	uint32_t next_id = ir.increase_bound_by(3);
3707 	auto &ib_type = expression_type(ib_var_id);
3708 	if (get_execution_model() == ExecutionModelTessellationControl)
3709 	{
3710 		// Tessellation control per-vertex I/O is presented as an array, so we must
3711 		// do the same with our struct here.
3712 		uint32_t ib_ptr_type_id = next_id++;
3713 		auto &ib_ptr_type = set<SPIRType>(ib_ptr_type_id, ib_type);
3714 		ib_ptr_type.parent_type = ib_ptr_type.type_alias = ib_type.self;
3715 		ib_ptr_type.pointer = true;
3716 		ib_ptr_type.pointer_depth++;
3717 		ib_ptr_type.storage =
3718 		    storage == StorageClassInput ?
3719 		        (msl_options.multi_patch_workgroup ? StorageClassStorageBuffer : StorageClassWorkgroup) :
3720 		        StorageClassStorageBuffer;
3721 		ir.meta[ib_ptr_type_id] = ir.meta[ib_type.self];
3722 		// To ensure that get_variable_data_type() doesn't strip off the pointer,
3723 		// which we need, use another pointer.
3724 		uint32_t ib_ptr_ptr_type_id = next_id++;
3725 		auto &ib_ptr_ptr_type = set<SPIRType>(ib_ptr_ptr_type_id, ib_ptr_type);
3726 		ib_ptr_ptr_type.parent_type = ib_ptr_type_id;
3727 		ib_ptr_ptr_type.type_alias = ib_type.self;
3728 		ib_ptr_ptr_type.storage = StorageClassFunction;
3729 		ir.meta[ib_ptr_ptr_type_id] = ir.meta[ib_type.self];
3730 
3731 		ib_ptr_var_id = next_id;
3732 		set<SPIRVariable>(ib_ptr_var_id, ib_ptr_ptr_type_id, StorageClassFunction, 0);
3733 		set_name(ib_ptr_var_id, storage == StorageClassInput ? "gl_in" : "gl_out");
3734 	}
3735 	else
3736 	{
3737 		// Tessellation evaluation per-vertex inputs are also presented as arrays.
3738 		// But, in Metal, this array uses a very special type, 'patch_control_point<T>',
3739 		// which is a container that can be used to access the control point data.
3740 		// To represent this, a special 'ControlPointArray' type has been added to the
3741 		// SPIRV-Cross type system. It should only be generated by and seen in the MSL
3742 		// backend (i.e. this one).
3743 		uint32_t pcp_type_id = next_id++;
3744 		auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
3745 		pcp_type.basetype = SPIRType::ControlPointArray;
3746 		pcp_type.parent_type = pcp_type.type_alias = ib_type.self;
3747 		pcp_type.storage = storage;
3748 		ir.meta[pcp_type_id] = ir.meta[ib_type.self];
3749 
3750 		ib_ptr_var_id = next_id;
3751 		set<SPIRVariable>(ib_ptr_var_id, pcp_type_id, storage, 0);
3752 		set_name(ib_ptr_var_id, "gl_in");
3753 		ir.meta[ib_ptr_var_id].decoration.qualified_alias = join(patch_stage_in_var_name, ".gl_in");
3754 	}
3755 	return ib_ptr_var_id;
3756 }
3757 
3758 // Ensure that the type is compatible with the builtin.
3759 // If it is, simply return the given type ID.
3760 // Otherwise, create a new type, and return it's ID.
ensure_correct_builtin_type(uint32_t type_id,BuiltIn builtin)3761 uint32_t CompilerMSL::ensure_correct_builtin_type(uint32_t type_id, BuiltIn builtin)
3762 {
3763 	auto &type = get<SPIRType>(type_id);
3764 
3765 	if ((builtin == BuiltInSampleMask && is_array(type)) ||
3766 	    ((builtin == BuiltInLayer || builtin == BuiltInViewportIndex || builtin == BuiltInFragStencilRefEXT) &&
3767 	     type.basetype != SPIRType::UInt))
3768 	{
3769 		uint32_t next_id = ir.increase_bound_by(type.pointer ? 2 : 1);
3770 		uint32_t base_type_id = next_id++;
3771 		auto &base_type = set<SPIRType>(base_type_id);
3772 		base_type.basetype = SPIRType::UInt;
3773 		base_type.width = 32;
3774 
3775 		if (!type.pointer)
3776 			return base_type_id;
3777 
3778 		uint32_t ptr_type_id = next_id++;
3779 		auto &ptr_type = set<SPIRType>(ptr_type_id);
3780 		ptr_type = base_type;
3781 		ptr_type.pointer = true;
3782 		ptr_type.pointer_depth++;
3783 		ptr_type.storage = type.storage;
3784 		ptr_type.parent_type = base_type_id;
3785 		return ptr_type_id;
3786 	}
3787 
3788 	return type_id;
3789 }
3790 
3791 // Ensure that the type is compatible with the shader input.
3792 // If it is, simply return the given type ID.
3793 // Otherwise, create a new type, and return its ID.
ensure_correct_input_type(uint32_t type_id,uint32_t location,uint32_t num_components,bool strip_array)3794 uint32_t CompilerMSL::ensure_correct_input_type(uint32_t type_id, uint32_t location, uint32_t num_components, bool strip_array)
3795 {
3796 	auto &type = get<SPIRType>(type_id);
3797 
3798 	uint32_t max_array_dimensions = strip_array ? 1 : 0;
3799 
3800 	// Struct and array types must match exactly.
3801 	if (type.basetype == SPIRType::Struct || type.array.size() > max_array_dimensions)
3802 		return type_id;
3803 
3804 	auto p_va = inputs_by_location.find(location);
3805 	if (p_va == end(inputs_by_location))
3806 	{
3807 		if (num_components > type.vecsize)
3808 			return build_extended_vector_type(type_id, num_components);
3809 		else
3810 			return type_id;
3811 	}
3812 
3813 	if (num_components == 0)
3814 		num_components = p_va->second.vecsize;
3815 
3816 	switch (p_va->second.format)
3817 	{
3818 	case MSL_SHADER_INPUT_FORMAT_UINT8:
3819 	{
3820 		switch (type.basetype)
3821 		{
3822 		case SPIRType::UByte:
3823 		case SPIRType::UShort:
3824 		case SPIRType::UInt:
3825 			if (num_components > type.vecsize)
3826 				return build_extended_vector_type(type_id, num_components);
3827 			else
3828 				return type_id;
3829 
3830 		case SPIRType::Short:
3831 			return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3832 			                                  SPIRType::UShort);
3833 		case SPIRType::Int:
3834 			return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3835 			                                  SPIRType::UInt);
3836 
3837 		default:
3838 			SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
3839 		}
3840 	}
3841 
3842 	case MSL_SHADER_INPUT_FORMAT_UINT16:
3843 	{
3844 		switch (type.basetype)
3845 		{
3846 		case SPIRType::UShort:
3847 		case SPIRType::UInt:
3848 			if (num_components > type.vecsize)
3849 				return build_extended_vector_type(type_id, num_components);
3850 			else
3851 				return type_id;
3852 
3853 		case SPIRType::Int:
3854 			return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3855 			                                  SPIRType::UInt);
3856 
3857 		default:
3858 			SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
3859 		}
3860 	}
3861 
3862 	default:
3863 		if (num_components > type.vecsize)
3864 			type_id = build_extended_vector_type(type_id, num_components);
3865 		break;
3866 	}
3867 
3868 	return type_id;
3869 }
3870 
mark_struct_members_packed(const SPIRType & type)3871 void CompilerMSL::mark_struct_members_packed(const SPIRType &type)
3872 {
3873 	set_extended_decoration(type.self, SPIRVCrossDecorationPhysicalTypePacked);
3874 
3875 	// Problem case! Struct needs to be placed at an awkward alignment.
3876 	// Mark every member of the child struct as packed.
3877 	uint32_t mbr_cnt = uint32_t(type.member_types.size());
3878 	for (uint32_t i = 0; i < mbr_cnt; i++)
3879 	{
3880 		auto &mbr_type = get<SPIRType>(type.member_types[i]);
3881 		if (mbr_type.basetype == SPIRType::Struct)
3882 		{
3883 			// Recursively mark structs as packed.
3884 			auto *struct_type = &mbr_type;
3885 			while (!struct_type->array.empty())
3886 				struct_type = &get<SPIRType>(struct_type->parent_type);
3887 			mark_struct_members_packed(*struct_type);
3888 		}
3889 		else if (!is_scalar(mbr_type))
3890 			set_extended_member_decoration(type.self, i, SPIRVCrossDecorationPhysicalTypePacked);
3891 	}
3892 }
3893 
mark_scalar_layout_structs(const SPIRType & type)3894 void CompilerMSL::mark_scalar_layout_structs(const SPIRType &type)
3895 {
3896 	uint32_t mbr_cnt = uint32_t(type.member_types.size());
3897 	for (uint32_t i = 0; i < mbr_cnt; i++)
3898 	{
3899 		auto &mbr_type = get<SPIRType>(type.member_types[i]);
3900 		if (mbr_type.basetype == SPIRType::Struct)
3901 		{
3902 			auto *struct_type = &mbr_type;
3903 			while (!struct_type->array.empty())
3904 				struct_type = &get<SPIRType>(struct_type->parent_type);
3905 
3906 			if (has_extended_decoration(struct_type->self, SPIRVCrossDecorationPhysicalTypePacked))
3907 				continue;
3908 
3909 			uint32_t msl_alignment = get_declared_struct_member_alignment_msl(type, i);
3910 			uint32_t msl_size = get_declared_struct_member_size_msl(type, i);
3911 			uint32_t spirv_offset = type_struct_member_offset(type, i);
3912 			uint32_t spirv_offset_next;
3913 			if (i + 1 < mbr_cnt)
3914 				spirv_offset_next = type_struct_member_offset(type, i + 1);
3915 			else
3916 				spirv_offset_next = spirv_offset + msl_size;
3917 
3918 			// Both are complicated cases. In scalar layout, a struct of float3 might just consume 12 bytes,
3919 			// and the next member will be placed at offset 12.
3920 			bool struct_is_misaligned = (spirv_offset % msl_alignment) != 0;
3921 			bool struct_is_too_large = spirv_offset + msl_size > spirv_offset_next;
3922 			uint32_t array_stride = 0;
3923 			bool struct_needs_explicit_padding = false;
3924 
3925 			// Verify that if a struct is used as an array that ArrayStride matches the effective size of the struct.
3926 			if (!mbr_type.array.empty())
3927 			{
3928 				array_stride = type_struct_member_array_stride(type, i);
3929 				uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
3930 				for (uint32_t dim = 0; dim < dimensions; dim++)
3931 				{
3932 					uint32_t array_size = to_array_size_literal(mbr_type, dim);
3933 					array_stride /= max(array_size, 1u);
3934 				}
3935 
3936 				// Set expected struct size based on ArrayStride.
3937 				struct_needs_explicit_padding = true;
3938 
3939 				// If struct size is larger than array stride, we might be able to fit, if we tightly pack.
3940 				if (get_declared_struct_size_msl(*struct_type) > array_stride)
3941 					struct_is_too_large = true;
3942 			}
3943 
3944 			if (struct_is_misaligned || struct_is_too_large)
3945 				mark_struct_members_packed(*struct_type);
3946 			mark_scalar_layout_structs(*struct_type);
3947 
3948 			if (struct_needs_explicit_padding)
3949 			{
3950 				msl_size = get_declared_struct_size_msl(*struct_type, true, true);
3951 				if (array_stride < msl_size)
3952 				{
3953 					SPIRV_CROSS_THROW("Cannot express an array stride smaller than size of struct type.");
3954 				}
3955 				else
3956 				{
3957 					if (has_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget))
3958 					{
3959 						if (array_stride !=
3960 						    get_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget))
3961 							SPIRV_CROSS_THROW(
3962 							    "A struct is used with different array strides. Cannot express this in MSL.");
3963 					}
3964 					else
3965 						set_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget, array_stride);
3966 				}
3967 			}
3968 		}
3969 	}
3970 }
3971 
3972 // Sort the members of the struct type by offset, and pack and then pad members where needed
3973 // to align MSL members with SPIR-V offsets. The struct members are iterated twice. Packing
3974 // occurs first, followed by padding, because packing a member reduces both its size and its
3975 // natural alignment, possibly requiring a padding member to be added ahead of it.
align_struct(SPIRType & ib_type,unordered_set<uint32_t> & aligned_structs)3976 void CompilerMSL::align_struct(SPIRType &ib_type, unordered_set<uint32_t> &aligned_structs)
3977 {
3978 	// We align structs recursively, so stop any redundant work.
3979 	ID &ib_type_id = ib_type.self;
3980 	if (aligned_structs.count(ib_type_id))
3981 		return;
3982 	aligned_structs.insert(ib_type_id);
3983 
3984 	// Sort the members of the interface structure by their offset.
3985 	// They should already be sorted per SPIR-V spec anyway.
3986 	MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Offset);
3987 	member_sorter.sort();
3988 
3989 	auto mbr_cnt = uint32_t(ib_type.member_types.size());
3990 
3991 	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
3992 	{
3993 		// Pack any dependent struct types before we pack a parent struct.
3994 		auto &mbr_type = get<SPIRType>(ib_type.member_types[mbr_idx]);
3995 		if (mbr_type.basetype == SPIRType::Struct)
3996 			align_struct(mbr_type, aligned_structs);
3997 	}
3998 
3999 	// Test the alignment of each member, and if a member should be closer to the previous
4000 	// member than the default spacing expects, it is likely that the previous member is in
4001 	// a packed format. If so, and the previous member is packable, pack it.
4002 	// For example ... this applies to any 3-element vector that is followed by a scalar.
4003 	uint32_t msl_offset = 0;
4004 	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
4005 	{
4006 		// This checks the member in isolation, if the member needs some kind of type remapping to conform to SPIR-V
4007 		// offsets, array strides and matrix strides.
4008 		ensure_member_packing_rules_msl(ib_type, mbr_idx);
4009 
4010 		// Align current offset to the current member's default alignment. If the member was packed, it will observe
4011 		// the updated alignment here.
4012 		uint32_t msl_align_mask = get_declared_struct_member_alignment_msl(ib_type, mbr_idx) - 1;
4013 		uint32_t aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;
4014 
4015 		// Fetch the member offset as declared in the SPIRV.
4016 		uint32_t spirv_mbr_offset = get_member_decoration(ib_type_id, mbr_idx, DecorationOffset);
4017 		if (spirv_mbr_offset > aligned_msl_offset)
4018 		{
4019 			// Since MSL and SPIR-V have slightly different struct member alignment and
4020 			// size rules, we'll pad to standard C-packing rules with a char[] array. If the member is farther
4021 			// away than C-packing, expects, add an inert padding member before the the member.
4022 			uint32_t padding_bytes = spirv_mbr_offset - aligned_msl_offset;
4023 			set_extended_member_decoration(ib_type_id, mbr_idx, SPIRVCrossDecorationPaddingTarget, padding_bytes);
4024 
4025 			// Re-align as a sanity check that aligning post-padding matches up.
4026 			msl_offset += padding_bytes;
4027 			aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;
4028 		}
4029 		else if (spirv_mbr_offset < aligned_msl_offset)
4030 		{
4031 			// This should not happen, but deal with unexpected scenarios.
4032 			// It *might* happen if a sub-struct has a larger alignment requirement in MSL than SPIR-V.
4033 			SPIRV_CROSS_THROW("Cannot represent buffer block correctly in MSL.");
4034 		}
4035 
4036 		assert(aligned_msl_offset == spirv_mbr_offset);
4037 
4038 		// Increment the current offset to be positioned immediately after the current member.
4039 		// Don't do this for the last member since it can be unsized, and it is not relevant for padding purposes here.
4040 		if (mbr_idx + 1 < mbr_cnt)
4041 			msl_offset = aligned_msl_offset + get_declared_struct_member_size_msl(ib_type, mbr_idx);
4042 	}
4043 }
4044 
validate_member_packing_rules_msl(const SPIRType & type,uint32_t index) const4045 bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32_t index) const
4046 {
4047 	auto &mbr_type = get<SPIRType>(type.member_types[index]);
4048 	uint32_t spirv_offset = get_member_decoration(type.self, index, DecorationOffset);
4049 
4050 	if (index + 1 < type.member_types.size())
4051 	{
4052 		// First, we will check offsets. If SPIR-V offset + MSL size > SPIR-V offset of next member,
4053 		// we *must* perform some kind of remapping, no way getting around it.
4054 		// We can always pad after this member if necessary, so that case is fine.
4055 		uint32_t spirv_offset_next = get_member_decoration(type.self, index + 1, DecorationOffset);
4056 		assert(spirv_offset_next >= spirv_offset);
4057 		uint32_t maximum_size = spirv_offset_next - spirv_offset;
4058 		uint32_t msl_mbr_size = get_declared_struct_member_size_msl(type, index);
4059 		if (msl_mbr_size > maximum_size)
4060 			return false;
4061 	}
4062 
4063 	if (!mbr_type.array.empty())
4064 	{
4065 		// If we have an array type, array stride must match exactly with SPIR-V.
4066 
4067 		// An exception to this requirement is if we have one array element.
4068 		// This comes from DX scalar layout workaround.
4069 		// If app tries to be cheeky and access the member out of bounds, this will not work, but this is the best we can do.
4070 		// In OpAccessChain with logical memory models, access chains must be in-bounds in SPIR-V specification.
4071 		bool relax_array_stride = mbr_type.array.back() == 1 && mbr_type.array_size_literal.back();
4072 
4073 		if (!relax_array_stride)
4074 		{
4075 			uint32_t spirv_array_stride = type_struct_member_array_stride(type, index);
4076 			uint32_t msl_array_stride = get_declared_struct_member_array_stride_msl(type, index);
4077 			if (spirv_array_stride != msl_array_stride)
4078 				return false;
4079 		}
4080 	}
4081 
4082 	if (is_matrix(mbr_type))
4083 	{
4084 		// Need to check MatrixStride as well.
4085 		uint32_t spirv_matrix_stride = type_struct_member_matrix_stride(type, index);
4086 		uint32_t msl_matrix_stride = get_declared_struct_member_matrix_stride_msl(type, index);
4087 		if (spirv_matrix_stride != msl_matrix_stride)
4088 			return false;
4089 	}
4090 
4091 	// Now, we check alignment.
4092 	uint32_t msl_alignment = get_declared_struct_member_alignment_msl(type, index);
4093 	if ((spirv_offset % msl_alignment) != 0)
4094 		return false;
4095 
4096 	// We're in the clear.
4097 	return true;
4098 }
4099 
4100 // Here we need to verify that the member type we declare conforms to Offset, ArrayStride or MatrixStride restrictions.
4101 // If there is a mismatch, we need to emit remapped types, either normal types, or "packed_X" types.
4102 // In odd cases we need to emit packed and remapped types, for e.g. weird matrices or arrays with weird array strides.
ensure_member_packing_rules_msl(SPIRType & ib_type,uint32_t index)4103 void CompilerMSL::ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t index)
4104 {
4105 	if (validate_member_packing_rules_msl(ib_type, index))
4106 		return;
4107 
4108 	// We failed validation.
4109 	// This case will be nightmare-ish to deal with. This could possibly happen if struct alignment does not quite
4110 	// match up with what we want. Scalar block layout comes to mind here where we might have to work around the rule
4111 	// that struct alignment == max alignment of all members and struct size depends on this alignment.
4112 	auto &mbr_type = get<SPIRType>(ib_type.member_types[index]);
4113 	if (mbr_type.basetype == SPIRType::Struct)
4114 		SPIRV_CROSS_THROW("Cannot perform any repacking for structs when it is used as a member of another struct.");
4115 
4116 	// Perform remapping here.
4117 	// There is nothing to be gained by using packed scalars, so don't attempt it.
4118 	if (!is_scalar(ib_type))
4119 		set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
4120 
4121 	// Try validating again, now with packed.
4122 	if (validate_member_packing_rules_msl(ib_type, index))
4123 		return;
4124 
4125 	// We're in deep trouble, and we need to create a new PhysicalType which matches up with what we expect.
4126 	// A lot of work goes here ...
4127 	// We will need remapping on Load and Store to translate the types between Logical and Physical.
4128 
4129 	// First, we check if we have small vector std140 array.
4130 	// We detect this if we have an array of vectors, and array stride is greater than number of elements.
4131 	if (!mbr_type.array.empty() && !is_matrix(mbr_type))
4132 	{
4133 		uint32_t array_stride = type_struct_member_array_stride(ib_type, index);
4134 
4135 		// Hack off array-of-arrays until we find the array stride per element we must have to make it work.
4136 		uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
4137 		for (uint32_t dim = 0; dim < dimensions; dim++)
4138 			array_stride /= max(to_array_size_literal(mbr_type, dim), 1u);
4139 
4140 		uint32_t elems_per_stride = array_stride / (mbr_type.width / 8);
4141 
4142 		if (elems_per_stride == 3)
4143 			SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
4144 		else if (elems_per_stride > 4)
4145 			SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
4146 
4147 		auto physical_type = mbr_type;
4148 		physical_type.vecsize = elems_per_stride;
4149 		physical_type.parent_type = 0;
4150 		uint32_t type_id = ir.increase_bound_by(1);
4151 		set<SPIRType>(type_id, physical_type);
4152 		set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID, type_id);
4153 		set_decoration(type_id, DecorationArrayStride, array_stride);
4154 
4155 		// Remove packed_ for vectors of size 1, 2 and 4.
4156 		if (has_extended_decoration(ib_type.self, SPIRVCrossDecorationPhysicalTypePacked))
4157 			SPIRV_CROSS_THROW("Unable to remove packed decoration as entire struct must be fully packed. Do not mix "
4158 			                  "scalar and std140 layout rules.");
4159 		else
4160 			unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
4161 	}
4162 	else if (is_matrix(mbr_type))
4163 	{
4164 		// MatrixStride might be std140-esque.
4165 		uint32_t matrix_stride = type_struct_member_matrix_stride(ib_type, index);
4166 
4167 		uint32_t elems_per_stride = matrix_stride / (mbr_type.width / 8);
4168 
4169 		if (elems_per_stride == 3)
4170 			SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
4171 		else if (elems_per_stride > 4)
4172 			SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
4173 
4174 		bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);
4175 
4176 		auto physical_type = mbr_type;
4177 		physical_type.parent_type = 0;
4178 		if (row_major)
4179 			physical_type.columns = elems_per_stride;
4180 		else
4181 			physical_type.vecsize = elems_per_stride;
4182 		uint32_t type_id = ir.increase_bound_by(1);
4183 		set<SPIRType>(type_id, physical_type);
4184 		set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID, type_id);
4185 
4186 		// Remove packed_ for vectors of size 1, 2 and 4.
4187 		if (has_extended_decoration(ib_type.self, SPIRVCrossDecorationPhysicalTypePacked))
4188 			SPIRV_CROSS_THROW("Unable to remove packed decoration as entire struct must be fully packed. Do not mix "
4189 			                  "scalar and std140 layout rules.");
4190 		else
4191 			unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
4192 	}
4193 	else
4194 		SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");
4195 
4196 	// Try validating again, now with physical type remapping.
4197 	if (validate_member_packing_rules_msl(ib_type, index))
4198 		return;
4199 
4200 	// We might have a particular odd scalar layout case where the last element of an array
4201 	// does not take up as much space as the ArrayStride or MatrixStride. This can happen with DX cbuffers.
4202 	// The "proper" workaround for this is extremely painful and essentially impossible in the edge case of float3[],
4203 	// so we hack around it by declaring the offending array or matrix with one less array size/col/row,
4204 	// and rely on padding to get the correct value. We will technically access arrays out of bounds into the padding region,
4205 	// but it should spill over gracefully without too much trouble. We rely on behavior like this for unsized arrays anyways.
4206 
4207 	// E.g. we might observe a physical layout of:
4208 	// { float2 a[2]; float b; } in cbuffer layout where ArrayStride of a is 16, but offset of b is 24, packed right after a[1] ...
4209 	uint32_t type_id = get_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID);
4210 	auto &type = get<SPIRType>(type_id);
4211 
4212 	// Modify the physical type in-place. This is safe since each physical type workaround is a copy.
4213 	if (is_array(type))
4214 	{
4215 		if (type.array.back() > 1)
4216 		{
4217 			if (!type.array_size_literal.back())
4218 				SPIRV_CROSS_THROW("Cannot apply scalar layout workaround with spec constant array size.");
4219 			type.array.back() -= 1;
4220 		}
4221 		else
4222 		{
4223 			// We have an array of size 1, so we cannot decrement that. Our only option now is to
4224 			// force a packed layout instead, and drop the physical type remap since ArrayStride is meaningless now.
4225 			unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID);
4226 			set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
4227 		}
4228 	}
4229 	else if (is_matrix(type))
4230 	{
4231 		bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);
4232 		if (!row_major)
4233 		{
4234 			// Slice off one column. If we only have 2 columns, this might turn the matrix into a vector with one array element instead.
4235 			if (type.columns > 2)
4236 			{
4237 				type.columns--;
4238 			}
4239 			else if (type.columns == 2)
4240 			{
4241 				type.columns = 1;
4242 				assert(type.array.empty());
4243 				type.array.push_back(1);
4244 				type.array_size_literal.push_back(true);
4245 			}
4246 		}
4247 		else
4248 		{
4249 			// Slice off one row. If we only have 2 rows, this might turn the matrix into a vector with one array element instead.
4250 			if (type.vecsize > 2)
4251 			{
4252 				type.vecsize--;
4253 			}
4254 			else if (type.vecsize == 2)
4255 			{
4256 				type.vecsize = type.columns;
4257 				type.columns = 1;
4258 				assert(type.array.empty());
4259 				type.array.push_back(1);
4260 				type.array_size_literal.push_back(true);
4261 			}
4262 		}
4263 	}
4264 
4265 	// This better validate now, or we must fail gracefully.
4266 	if (!validate_member_packing_rules_msl(ib_type, index))
4267 		SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");
4268 }
4269 
emit_store_statement(uint32_t lhs_expression,uint32_t rhs_expression)4270 void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression)
4271 {
4272 	auto &type = expression_type(rhs_expression);
4273 
4274 	bool lhs_remapped_type = has_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypeID);
4275 	bool lhs_packed_type = has_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypePacked);
4276 	auto *lhs_e = maybe_get<SPIRExpression>(lhs_expression);
4277 	auto *rhs_e = maybe_get<SPIRExpression>(rhs_expression);
4278 
4279 	bool transpose = lhs_e && lhs_e->need_transpose;
4280 
4281 	// No physical type remapping, and no packed type, so can just emit a store directly.
4282 	if (!lhs_remapped_type && !lhs_packed_type)
4283 	{
4284 		// We might not be dealing with remapped physical types or packed types,
4285 		// but we might be doing a clean store to a row-major matrix.
4286 		// In this case, we just flip transpose states, and emit the store, a transpose must be in the RHS expression, if any.
4287 		if (is_matrix(type) && lhs_e && lhs_e->need_transpose)
4288 		{
4289 			lhs_e->need_transpose = false;
4290 
4291 			if (rhs_e && rhs_e->need_transpose)
4292 			{
4293 				// Direct copy, but might need to unpack RHS.
4294 				// Skip the transpose, as we will transpose when writing to LHS and transpose(transpose(T)) == T.
4295 				rhs_e->need_transpose = false;
4296 				statement(to_expression(lhs_expression), " = ", to_unpacked_row_major_matrix_expression(rhs_expression),
4297 				          ";");
4298 				rhs_e->need_transpose = true;
4299 			}
4300 			else
4301 				statement(to_expression(lhs_expression), " = transpose(", to_unpacked_expression(rhs_expression), ");");
4302 
4303 			lhs_e->need_transpose = true;
4304 			register_write(lhs_expression);
4305 		}
4306 		else if (lhs_e && lhs_e->need_transpose)
4307 		{
4308 			lhs_e->need_transpose = false;
4309 
4310 			// Storing a column to a row-major matrix. Unroll the write.
4311 			for (uint32_t c = 0; c < type.vecsize; c++)
4312 			{
4313 				auto lhs_expr = to_dereferenced_expression(lhs_expression);
4314 				auto column_index = lhs_expr.find_last_of('[');
4315 				if (column_index != string::npos)
4316 				{
4317 					statement(lhs_expr.insert(column_index, join('[', c, ']')), " = ",
4318 					          to_extract_component_expression(rhs_expression, c), ";");
4319 				}
4320 			}
4321 			lhs_e->need_transpose = true;
4322 			register_write(lhs_expression);
4323 		}
4324 		else
4325 			CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
4326 	}
4327 	else if (!lhs_remapped_type && !is_matrix(type) && !transpose)
4328 	{
4329 		// Even if the target type is packed, we can directly store to it. We cannot store to packed matrices directly,
4330 		// since they are declared as array of vectors instead, and we need the fallback path below.
4331 		CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
4332 	}
4333 	else
4334 	{
4335 		// Special handling when storing to a remapped physical type.
4336 		// This is mostly to deal with std140 padded matrices or vectors.
4337 
4338 		TypeID physical_type_id = lhs_remapped_type ?
4339 		                              ID(get_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypeID)) :
4340 		                              type.self;
4341 
4342 		auto &physical_type = get<SPIRType>(physical_type_id);
4343 
4344 		if (is_matrix(type))
4345 		{
4346 			const char *packed_pfx = lhs_packed_type ? "packed_" : "";
4347 
4348 			// Packed matrices are stored as arrays of packed vectors, so we need
4349 			// to assign the vectors one at a time.
4350 			// For row-major matrices, we need to transpose the *right-hand* side,
4351 			// not the left-hand side.
4352 
4353 			// Lots of cases to cover here ...
4354 
4355 			bool rhs_transpose = rhs_e && rhs_e->need_transpose;
4356 			SPIRType write_type = type;
4357 			string cast_expr;
4358 
4359 			// We're dealing with transpose manually.
4360 			if (rhs_transpose)
4361 				rhs_e->need_transpose = false;
4362 
4363 			if (transpose)
4364 			{
4365 				// We're dealing with transpose manually.
4366 				lhs_e->need_transpose = false;
4367 				write_type.vecsize = type.columns;
4368 				write_type.columns = 1;
4369 
4370 				if (physical_type.columns != type.columns)
4371 					cast_expr = join("(device ", packed_pfx, type_to_glsl(write_type), "&)");
4372 
4373 				if (rhs_transpose)
4374 				{
4375 					// If RHS is also transposed, we can just copy row by row.
4376 					for (uint32_t i = 0; i < type.vecsize; i++)
4377 					{
4378 						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ",
4379 						          to_unpacked_row_major_matrix_expression(rhs_expression), "[", i, "];");
4380 					}
4381 				}
4382 				else
4383 				{
4384 					auto vector_type = expression_type(rhs_expression);
4385 					vector_type.vecsize = vector_type.columns;
4386 					vector_type.columns = 1;
4387 
4388 					// Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
4389 					// so pick out individual components instead.
4390 					for (uint32_t i = 0; i < type.vecsize; i++)
4391 					{
4392 						string rhs_row = type_to_glsl_constructor(vector_type) + "(";
4393 						for (uint32_t j = 0; j < vector_type.vecsize; j++)
4394 						{
4395 							rhs_row += join(to_enclosed_unpacked_expression(rhs_expression), "[", j, "][", i, "]");
4396 							if (j + 1 < vector_type.vecsize)
4397 								rhs_row += ", ";
4398 						}
4399 						rhs_row += ")";
4400 
4401 						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";");
4402 					}
4403 				}
4404 
4405 				// We're dealing with transpose manually.
4406 				lhs_e->need_transpose = true;
4407 			}
4408 			else
4409 			{
4410 				write_type.columns = 1;
4411 
4412 				if (physical_type.vecsize != type.vecsize)
4413 					cast_expr = join("(device ", packed_pfx, type_to_glsl(write_type), "&)");
4414 
4415 				if (rhs_transpose)
4416 				{
4417 					auto vector_type = expression_type(rhs_expression);
4418 					vector_type.columns = 1;
4419 
4420 					// Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
4421 					// so pick out individual components instead.
4422 					for (uint32_t i = 0; i < type.columns; i++)
4423 					{
4424 						string rhs_row = type_to_glsl_constructor(vector_type) + "(";
4425 						for (uint32_t j = 0; j < vector_type.vecsize; j++)
4426 						{
4427 							// Need to explicitly unpack expression since we've mucked with transpose state.
4428 							auto unpacked_expr = to_unpacked_row_major_matrix_expression(rhs_expression);
4429 							rhs_row += join(unpacked_expr, "[", j, "][", i, "]");
4430 							if (j + 1 < vector_type.vecsize)
4431 								rhs_row += ", ";
4432 						}
4433 						rhs_row += ")";
4434 
4435 						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";");
4436 					}
4437 				}
4438 				else
4439 				{
4440 					// Copy column-by-column.
4441 					for (uint32_t i = 0; i < type.columns; i++)
4442 					{
4443 						statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ",
4444 						          to_enclosed_unpacked_expression(rhs_expression), "[", i, "];");
4445 					}
4446 				}
4447 			}
4448 
4449 			// We're dealing with transpose manually.
4450 			if (rhs_transpose)
4451 				rhs_e->need_transpose = true;
4452 		}
4453 		else if (transpose)
4454 		{
4455 			lhs_e->need_transpose = false;
4456 
4457 			SPIRType write_type = type;
4458 			write_type.vecsize = 1;
4459 			write_type.columns = 1;
4460 
4461 			// Storing a column to a row-major matrix. Unroll the write.
4462 			for (uint32_t c = 0; c < type.vecsize; c++)
4463 			{
4464 				auto lhs_expr = to_enclosed_expression(lhs_expression);
4465 				auto column_index = lhs_expr.find_last_of('[');
4466 				if (column_index != string::npos)
4467 				{
4468 					statement("((device ", type_to_glsl(write_type), "*)&",
4469 					          lhs_expr.insert(column_index, join('[', c, ']', ")")), " = ",
4470 					          to_extract_component_expression(rhs_expression, c), ";");
4471 				}
4472 			}
4473 
4474 			lhs_e->need_transpose = true;
4475 		}
4476 		else if ((is_matrix(physical_type) || is_array(physical_type)) && physical_type.vecsize > type.vecsize)
4477 		{
4478 			assert(type.vecsize >= 1 && type.vecsize <= 3);
4479 
4480 			// If we have packed types, we cannot use swizzled stores.
4481 			// We could technically unroll the store for each element if needed.
4482 			// When remapping to a std140 physical type, we always get float4,
4483 			// and the packed decoration should always be removed.
4484 			assert(!lhs_packed_type);
4485 
4486 			string lhs = to_dereferenced_expression(lhs_expression);
4487 			string rhs = to_pointer_expression(rhs_expression);
4488 
4489 			// Unpack the expression so we can store to it with a float or float2.
4490 			// It's still an l-value, so it's fine. Most other unpacking of expressions turn them into r-values instead.
4491 			lhs = join("(device ", type_to_glsl(type), "&)", enclose_expression(lhs));
4492 			if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
4493 				statement(lhs, " = ", rhs, ";");
4494 		}
4495 		else if (!is_matrix(type))
4496 		{
4497 			string lhs = to_dereferenced_expression(lhs_expression);
4498 			string rhs = to_pointer_expression(rhs_expression);
4499 			if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
4500 				statement(lhs, " = ", rhs, ";");
4501 		}
4502 
4503 		register_write(lhs_expression);
4504 	}
4505 }
4506 
// Reports whether expr_str terminates with exactly the characters of ending.
// An empty ending matches every string; an ending longer than expr_str never matches.
static bool expression_ends_with(const std::string &expr_str, const std::string &ending)
{
	const auto tail_len = ending.size();

	// A suffix cannot be longer than the string it is supposed to terminate.
	if (tail_len > expr_str.size())
		return false;

	// Compare the suffix against the last tail_len characters of the expression.
	return std::equal(ending.begin(), ending.end(), expr_str.end() - tail_len);
}
4514 
4515 // Converts the format of the current expression from packed to unpacked,
4516 // by wrapping the expression in a constructor of the appropriate type.
4517 // Also, handle special physical ID remapping scenarios, similar to emit_store_statement().
unpack_expression_type(string expr_str,const SPIRType & type,uint32_t physical_type_id,bool packed,bool row_major)4518 string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type, uint32_t physical_type_id,
4519                                            bool packed, bool row_major)
4520 {
4521 	// Trivial case, nothing to do.
4522 	if (physical_type_id == 0 && !packed)
4523 		return expr_str;
4524 
4525 	const SPIRType *physical_type = nullptr;
4526 	if (physical_type_id)
4527 		physical_type = &get<SPIRType>(physical_type_id);
4528 
4529 	static const char *swizzle_lut[] = {
4530 		".x",
4531 		".xy",
4532 		".xyz",
4533 	};
4534 
4535 	if (physical_type && is_vector(*physical_type) && is_array(*physical_type) &&
4536 	    physical_type->vecsize > type.vecsize && !expression_ends_with(expr_str, swizzle_lut[type.vecsize - 1]))
4537 	{
4538 		// std140 array cases for vectors.
4539 		assert(type.vecsize >= 1 && type.vecsize <= 3);
4540 		return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
4541 	}
4542 	else if (physical_type && is_matrix(*physical_type) && is_vector(type) && physical_type->vecsize > type.vecsize)
4543 	{
4544 		// Extract column from padded matrix.
4545 		assert(type.vecsize >= 1 && type.vecsize <= 3);
4546 		return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
4547 	}
4548 	else if (is_matrix(type))
4549 	{
4550 		// Packed matrices are stored as arrays of packed vectors. Unfortunately,
4551 		// we can't just pass the array straight to the matrix constructor. We have to
4552 		// pass each vector individually, so that they can be unpacked to normal vectors.
4553 		if (!physical_type)
4554 			physical_type = &type;
4555 
4556 		uint32_t vecsize = type.vecsize;
4557 		uint32_t columns = type.columns;
4558 		if (row_major)
4559 			swap(vecsize, columns);
4560 
4561 		uint32_t physical_vecsize = row_major ? physical_type->columns : physical_type->vecsize;
4562 
4563 		const char *base_type = type.width == 16 ? "half" : "float";
4564 		string unpack_expr = join(base_type, columns, "x", vecsize, "(");
4565 
4566 		const char *load_swiz = "";
4567 
4568 		if (physical_vecsize != vecsize)
4569 			load_swiz = swizzle_lut[vecsize - 1];
4570 
4571 		for (uint32_t i = 0; i < columns; i++)
4572 		{
4573 			if (i > 0)
4574 				unpack_expr += ", ";
4575 
4576 			if (packed)
4577 				unpack_expr += join(base_type, physical_vecsize, "(", expr_str, "[", i, "]", ")", load_swiz);
4578 			else
4579 				unpack_expr += join(expr_str, "[", i, "]", load_swiz);
4580 		}
4581 
4582 		unpack_expr += ")";
4583 		return unpack_expr;
4584 	}
4585 	else
4586 	{
4587 		return join(type_to_glsl(type), "(", expr_str, ")");
4588 	}
4589 }
4590 
4591 // Emits the file header info
emit_header()4592 void CompilerMSL::emit_header()
4593 {
4594 	// This particular line can be overridden during compilation, so make it a flag and not a pragma line.
4595 	if (suppress_missing_prototypes)
4596 		statement("#pragma clang diagnostic ignored \"-Wmissing-prototypes\"");
4597 
4598 	// Disable warning about missing braces for array<T> template to make arrays a value type
4599 	if (spv_function_implementations.count(SPVFuncImplUnsafeArray) != 0)
4600 		statement("#pragma clang diagnostic ignored \"-Wmissing-braces\"");
4601 
4602 	for (auto &pragma : pragma_lines)
4603 		statement(pragma);
4604 
4605 	if (!pragma_lines.empty() || suppress_missing_prototypes)
4606 		statement("");
4607 
4608 	statement("#include <metal_stdlib>");
4609 	statement("#include <simd/simd.h>");
4610 
4611 	for (auto &header : header_lines)
4612 		statement(header);
4613 
4614 	statement("");
4615 	statement("using namespace metal;");
4616 	statement("");
4617 
4618 	for (auto &td : typedef_lines)
4619 		statement(td);
4620 
4621 	if (!typedef_lines.empty())
4622 		statement("");
4623 }
4624 
add_pragma_line(const string & line)4625 void CompilerMSL::add_pragma_line(const string &line)
4626 {
4627 	auto rslt = pragma_lines.insert(line);
4628 	if (rslt.second)
4629 		force_recompile();
4630 }
4631 
add_typedef_line(const string & line)4632 void CompilerMSL::add_typedef_line(const string &line)
4633 {
4634 	auto rslt = typedef_lines.insert(line);
4635 	if (rslt.second)
4636 		force_recompile();
4637 }
4638 
4639 // Template struct like spvUnsafeArray<> need to be declared *before* any resources are declared
emit_custom_templates()4640 void CompilerMSL::emit_custom_templates()
4641 {
4642 	for (const auto &spv_func : spv_function_implementations)
4643 	{
4644 		switch (spv_func)
4645 		{
4646 		case SPVFuncImplUnsafeArray:
4647 			statement("template<typename T, size_t Num>");
4648 			statement("struct spvUnsafeArray");
4649 			begin_scope();
4650 			statement("T elements[Num ? Num : 1];");
4651 			statement("");
4652 			statement("thread T& operator [] (size_t pos) thread");
4653 			begin_scope();
4654 			statement("return elements[pos];");
4655 			end_scope();
4656 			statement("constexpr const thread T& operator [] (size_t pos) const thread");
4657 			begin_scope();
4658 			statement("return elements[pos];");
4659 			end_scope();
4660 			statement("");
4661 			statement("device T& operator [] (size_t pos) device");
4662 			begin_scope();
4663 			statement("return elements[pos];");
4664 			end_scope();
4665 			statement("constexpr const device T& operator [] (size_t pos) const device");
4666 			begin_scope();
4667 			statement("return elements[pos];");
4668 			end_scope();
4669 			statement("");
4670 			statement("constexpr const constant T& operator [] (size_t pos) const constant");
4671 			begin_scope();
4672 			statement("return elements[pos];");
4673 			end_scope();
4674 			statement("");
4675 			statement("threadgroup T& operator [] (size_t pos) threadgroup");
4676 			begin_scope();
4677 			statement("return elements[pos];");
4678 			end_scope();
4679 			statement("constexpr const threadgroup T& operator [] (size_t pos) const threadgroup");
4680 			begin_scope();
4681 			statement("return elements[pos];");
4682 			end_scope();
4683 			end_scope_decl();
4684 			statement("");
4685 			break;
4686 
4687 		default:
4688 			break;
4689 		}
4690 	}
4691 }
4692 
4693 // Emits any needed custom function bodies.
4694 // Metal helper functions must be static force-inline, i.e. static inline __attribute__((always_inline))
4695 // otherwise they will cause problems when linked together in a single Metallib.
emit_custom_functions()4696 void CompilerMSL::emit_custom_functions()
4697 {
4698 	for (uint32_t i = kArrayCopyMultidimMax; i >= 2; i--)
4699 		if (spv_function_implementations.count(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i)))
4700 			spv_function_implementations.insert(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i - 1));
4701 
4702 	if (spv_function_implementations.count(SPVFuncImplDynamicImageSampler))
4703 	{
4704 		// Unfortunately, this one needs a lot of the other functions to compile OK.
4705 		if (!msl_options.supports_msl_version(2))
4706 			SPIRV_CROSS_THROW(
4707 			    "spvDynamicImageSampler requires default-constructible texture objects, which require MSL 2.0.");
4708 		spv_function_implementations.insert(SPVFuncImplForwardArgs);
4709 		spv_function_implementations.insert(SPVFuncImplTextureSwizzle);
4710 		if (msl_options.swizzle_texture_samples)
4711 			spv_function_implementations.insert(SPVFuncImplGatherSwizzle);
4712 		for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
4713 		     i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
4714 			spv_function_implementations.insert(static_cast<SPVFuncImpl>(i));
4715 		spv_function_implementations.insert(SPVFuncImplExpandITUFullRange);
4716 		spv_function_implementations.insert(SPVFuncImplExpandITUNarrowRange);
4717 		spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT709);
4718 		spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT601);
4719 		spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT2020);
4720 	}
4721 
4722 	for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
4723 	     i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
4724 		if (spv_function_implementations.count(static_cast<SPVFuncImpl>(i)))
4725 			spv_function_implementations.insert(SPVFuncImplForwardArgs);
4726 
4727 	if (spv_function_implementations.count(SPVFuncImplTextureSwizzle) ||
4728 	    spv_function_implementations.count(SPVFuncImplGatherSwizzle) ||
4729 	    spv_function_implementations.count(SPVFuncImplGatherCompareSwizzle))
4730 	{
4731 		spv_function_implementations.insert(SPVFuncImplForwardArgs);
4732 		spv_function_implementations.insert(SPVFuncImplGetSwizzle);
4733 	}
4734 
4735 	for (const auto &spv_func : spv_function_implementations)
4736 	{
4737 		switch (spv_func)
4738 		{
4739 		case SPVFuncImplMod:
4740 			statement("// Implementation of the GLSL mod() function, which is slightly different than Metal fmod()");
4741 			statement("template<typename Tx, typename Ty>");
4742 			statement("inline Tx mod(Tx x, Ty y)");
4743 			begin_scope();
4744 			statement("return x - y * floor(x / y);");
4745 			end_scope();
4746 			statement("");
4747 			break;
4748 
4749 		case SPVFuncImplRadians:
4750 			statement("// Implementation of the GLSL radians() function");
4751 			statement("template<typename T>");
4752 			statement("inline T radians(T d)");
4753 			begin_scope();
4754 			statement("return d * T(0.01745329251);");
4755 			end_scope();
4756 			statement("");
4757 			break;
4758 
4759 		case SPVFuncImplDegrees:
4760 			statement("// Implementation of the GLSL degrees() function");
4761 			statement("template<typename T>");
4762 			statement("inline T degrees(T r)");
4763 			begin_scope();
4764 			statement("return r * T(57.2957795131);");
4765 			end_scope();
4766 			statement("");
4767 			break;
4768 
4769 		case SPVFuncImplFindILsb:
4770 			statement("// Implementation of the GLSL findLSB() function");
4771 			statement("template<typename T>");
4772 			statement("inline T spvFindLSB(T x)");
4773 			begin_scope();
4774 			statement("return select(ctz(x), T(-1), x == T(0));");
4775 			end_scope();
4776 			statement("");
4777 			break;
4778 
4779 		case SPVFuncImplFindUMsb:
4780 			statement("// Implementation of the unsigned GLSL findMSB() function");
4781 			statement("template<typename T>");
4782 			statement("inline T spvFindUMSB(T x)");
4783 			begin_scope();
4784 			statement("return select(clz(T(0)) - (clz(x) + T(1)), T(-1), x == T(0));");
4785 			end_scope();
4786 			statement("");
4787 			break;
4788 
4789 		case SPVFuncImplFindSMsb:
4790 			statement("// Implementation of the signed GLSL findMSB() function");
4791 			statement("template<typename T>");
4792 			statement("inline T spvFindSMSB(T x)");
4793 			begin_scope();
4794 			statement("T v = select(x, T(-1) - x, x < T(0));");
4795 			statement("return select(clz(T(0)) - (clz(v) + T(1)), T(-1), v == T(0));");
4796 			end_scope();
4797 			statement("");
4798 			break;
4799 
4800 		case SPVFuncImplSSign:
4801 			statement("// Implementation of the GLSL sign() function for integer types");
4802 			statement("template<typename T, typename E = typename enable_if<is_integral<T>::value>::type>");
4803 			statement("inline T sign(T x)");
4804 			begin_scope();
4805 			statement("return select(select(select(x, T(0), x == T(0)), T(1), x > T(0)), T(-1), x < T(0));");
4806 			end_scope();
4807 			statement("");
4808 			break;
4809 
4810 		case SPVFuncImplArrayCopy:
4811 		case SPVFuncImplArrayOfArrayCopy2Dim:
4812 		case SPVFuncImplArrayOfArrayCopy3Dim:
4813 		case SPVFuncImplArrayOfArrayCopy4Dim:
4814 		case SPVFuncImplArrayOfArrayCopy5Dim:
4815 		case SPVFuncImplArrayOfArrayCopy6Dim:
4816 		{
4817 			// Unfortunately we cannot template on the address space, so combinatorial explosion it is.
4818 			static const char *function_name_tags[] = {
4819 				"FromConstantToStack",     "FromConstantToThreadGroup", "FromStackToStack",
4820 				"FromStackToThreadGroup",  "FromThreadGroupToStack",    "FromThreadGroupToThreadGroup",
4821 				"FromDeviceToDevice",      "FromConstantToDevice",      "FromStackToDevice",
4822 				"FromThreadGroupToDevice", "FromDeviceToStack",         "FromDeviceToThreadGroup",
4823 			};
4824 
4825 			static const char *src_address_space[] = {
4826 				"constant",          "constant",          "thread const", "thread const",
4827 				"threadgroup const", "threadgroup const", "device const", "constant",
4828 				"thread const",      "threadgroup const", "device const", "device const",
4829 			};
4830 
4831 			static const char *dst_address_space[] = {
4832 				"thread", "threadgroup", "thread", "threadgroup", "thread", "threadgroup",
4833 				"device", "device",      "device", "device",      "thread", "threadgroup",
4834 			};
4835 
4836 			for (uint32_t variant = 0; variant < 12; variant++)
4837 			{
4838 				uint32_t dimensions = spv_func - SPVFuncImplArrayCopyMultidimBase;
4839 				string tmp = "template<typename T";
4840 				for (uint8_t i = 0; i < dimensions; i++)
4841 				{
4842 					tmp += ", uint ";
4843 					tmp += 'A' + i;
4844 				}
4845 				tmp += ">";
4846 				statement(tmp);
4847 
4848 				string array_arg;
4849 				for (uint8_t i = 0; i < dimensions; i++)
4850 				{
4851 					array_arg += "[";
4852 					array_arg += 'A' + i;
4853 					array_arg += "]";
4854 				}
4855 
4856 				statement("inline void spvArrayCopy", function_name_tags[variant], dimensions, "(",
4857 				          dst_address_space[variant], " T (&dst)", array_arg, ", ", src_address_space[variant],
4858 				          " T (&src)", array_arg, ")");
4859 
4860 				begin_scope();
4861 				statement("for (uint i = 0; i < A; i++)");
4862 				begin_scope();
4863 
4864 				if (dimensions == 1)
4865 					statement("dst[i] = src[i];");
4866 				else
4867 					statement("spvArrayCopy", function_name_tags[variant], dimensions - 1, "(dst[i], src[i]);");
4868 				end_scope();
4869 				end_scope();
4870 				statement("");
4871 			}
4872 			break;
4873 		}
4874 
4875 		// Support for Metal 2.1's new texture_buffer type.
4876 		case SPVFuncImplTexelBufferCoords:
4877 		{
4878 			if (msl_options.texel_buffer_texture_width > 0)
4879 			{
4880 				string tex_width_str = convert_to_string(msl_options.texel_buffer_texture_width);
4881 				statement("// Returns 2D texture coords corresponding to 1D texel buffer coords");
4882 				statement(force_inline);
4883 				statement("uint2 spvTexelBufferCoord(uint tc)");
4884 				begin_scope();
4885 				statement(join("return uint2(tc % ", tex_width_str, ", tc / ", tex_width_str, ");"));
4886 				end_scope();
4887 				statement("");
4888 			}
4889 			else
4890 			{
4891 				statement("// Returns 2D texture coords corresponding to 1D texel buffer coords");
4892 				statement(
4893 				    "#define spvTexelBufferCoord(tc, tex) uint2((tc) % (tex).get_width(), (tc) / (tex).get_width())");
4894 				statement("");
4895 			}
4896 			break;
4897 		}
4898 
4899 		// Emulate texture2D atomic operations
4900 		case SPVFuncImplImage2DAtomicCoords:
4901 		{
4902 			if (msl_options.supports_msl_version(1, 2))
4903 			{
4904 				statement("// The required alignment of a linear texture of R32Uint format.");
4905 				statement("constant uint spvLinearTextureAlignmentOverride [[function_constant(",
4906 				          msl_options.r32ui_alignment_constant_id, ")]];");
4907 				statement("constant uint spvLinearTextureAlignment = ",
4908 				          "is_function_constant_defined(spvLinearTextureAlignmentOverride) ? ",
4909 				          "spvLinearTextureAlignmentOverride : ", msl_options.r32ui_linear_texture_alignment, ";");
4910 			}
4911 			else
4912 			{
4913 				statement("// The required alignment of a linear texture of R32Uint format.");
4914 				statement("constant uint spvLinearTextureAlignment = ", msl_options.r32ui_linear_texture_alignment,
4915 				          ";");
4916 			}
4917 			statement("// Returns buffer coords corresponding to 2D texture coords for emulating 2D texture atomics");
4918 			statement("#define spvImage2DAtomicCoord(tc, tex) (((((tex).get_width() + ",
4919 			          " spvLinearTextureAlignment / 4 - 1) & ~(",
4920 			          " spvLinearTextureAlignment / 4 - 1)) * (tc).y) + (tc).x)");
4921 			statement("");
4922 			break;
4923 		}
4924 
4925 		// "fadd" intrinsic support
4926 		case SPVFuncImplFAdd:
4927 			statement("template<typename T>");
4928 			statement("T spvFAdd(T l, T r)");
4929 			begin_scope();
4930 			statement("return fma(T(1), l, r);");
4931 			end_scope();
4932 			statement("");
4933 			break;
4934 
4935 		// "fsub" intrinsic support
4936 		case SPVFuncImplFSub:
4937 			statement("template<typename T>");
4938 			statement("T spvFSub(T l, T r)");
4939 			begin_scope();
4940 			statement("return fma(T(-1), r, l);");
4941 			end_scope();
4942 			statement("");
4943 			break;
4944 
4945 		// "fmul' intrinsic support
4946 		case SPVFuncImplFMul:
4947 			statement("template<typename T>");
4948 			statement("T spvFMul(T l, T r)");
4949 			begin_scope();
4950 			statement("return fma(l, r, T(0));");
4951 			end_scope();
4952 			statement("");
4953 
4954 			statement("template<typename T, int Cols, int Rows>");
4955 			statement("vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)");
4956 			begin_scope();
4957 			statement("vec<T, Cols> res = vec<T, Cols>(0);");
4958 			statement("for (uint i = Rows; i > 0; --i)");
4959 			begin_scope();
4960 			statement("vec<T, Cols> tmp(0);");
4961 			statement("for (uint j = 0; j < Cols; ++j)");
4962 			begin_scope();
4963 			statement("tmp[j] = m[j][i - 1];");
4964 			end_scope();
4965 			statement("res = fma(tmp, vec<T, Cols>(v[i - 1]), res);");
4966 			end_scope();
4967 			statement("return res;");
4968 			end_scope();
4969 			statement("");
4970 
4971 			statement("template<typename T, int Cols, int Rows>");
4972 			statement("vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)");
4973 			begin_scope();
4974 			statement("vec<T, Rows> res = vec<T, Rows>(0);");
4975 			statement("for (uint i = Cols; i > 0; --i)");
4976 			begin_scope();
4977 			statement("res = fma(m[i - 1], vec<T, Rows>(v[i - 1]), res);");
4978 			end_scope();
4979 			statement("return res;");
4980 			end_scope();
4981 			statement("");
4982 
4983 			statement("template<typename T, int LCols, int LRows, int RCols, int RRows>");
4984 			statement(
4985 			    "matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)");
4986 			begin_scope();
4987 			statement("matrix<T, RCols, LRows> res;");
4988 			statement("for (uint i = 0; i < RCols; i++)");
4989 			begin_scope();
4990 			statement("vec<T, RCols> tmp(0);");
4991 			statement("for (uint j = 0; j < LCols; j++)");
4992 			begin_scope();
4993 			statement("tmp = fma(vec<T, RCols>(r[i][j]), l[j], tmp);");
4994 			end_scope();
4995 			statement("res[i] = tmp;");
4996 			end_scope();
4997 			statement("return res;");
4998 			end_scope();
4999 			statement("");
5000 			break;
5001 
5002 		// Emulate texturecube_array with texture2d_array for iOS where this type is not available
5003 		case SPVFuncImplCubemapTo2DArrayFace:
5004 			statement(force_inline);
5005 			statement("float3 spvCubemapTo2DArrayFace(float3 P)");
5006 			begin_scope();
5007 			statement("float3 Coords = abs(P.xyz);");
5008 			statement("float CubeFace = 0;");
5009 			statement("float ProjectionAxis = 0;");
5010 			statement("float u = 0;");
5011 			statement("float v = 0;");
5012 			statement("if (Coords.x >= Coords.y && Coords.x >= Coords.z)");
5013 			begin_scope();
5014 			statement("CubeFace = P.x >= 0 ? 0 : 1;");
5015 			statement("ProjectionAxis = Coords.x;");
5016 			statement("u = P.x >= 0 ? -P.z : P.z;");
5017 			statement("v = -P.y;");
5018 			end_scope();
5019 			statement("else if (Coords.y >= Coords.x && Coords.y >= Coords.z)");
5020 			begin_scope();
5021 			statement("CubeFace = P.y >= 0 ? 2 : 3;");
5022 			statement("ProjectionAxis = Coords.y;");
5023 			statement("u = P.x;");
5024 			statement("v = P.y >= 0 ? P.z : -P.z;");
5025 			end_scope();
5026 			statement("else");
5027 			begin_scope();
5028 			statement("CubeFace = P.z >= 0 ? 4 : 5;");
5029 			statement("ProjectionAxis = Coords.z;");
5030 			statement("u = P.z >= 0 ? P.x : -P.x;");
5031 			statement("v = -P.y;");
5032 			end_scope();
5033 			statement("u = 0.5 * (u/ProjectionAxis + 1);");
5034 			statement("v = 0.5 * (v/ProjectionAxis + 1);");
5035 			statement("return float3(u, v, CubeFace);");
5036 			end_scope();
5037 			statement("");
5038 			break;
5039 
5040 		case SPVFuncImplInverse4x4:
5041 			statement("// Returns the determinant of a 2x2 matrix.");
5042 			statement(force_inline);
5043 			statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
5044 			begin_scope();
5045 			statement("return a1 * b2 - b1 * a2;");
5046 			end_scope();
5047 			statement("");
5048 
5049 			statement("// Returns the determinant of a 3x3 matrix.");
5050 			statement(force_inline);
5051 			statement("float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
5052 			          "float c2, float c3)");
5053 			begin_scope();
5054 			statement("return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * spvDet2x2(a2, a3, "
5055 			          "b2, b3);");
5056 			end_scope();
5057 			statement("");
5058 			statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
5059 			statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
5060 			statement(force_inline);
5061 			statement("float4x4 spvInverse4x4(float4x4 m)");
5062 			begin_scope();
5063 			statement("float4x4 adj;	// The adjoint matrix (inverse after dividing by determinant)");
5064 			statement_no_indent("");
5065 			statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
5066 			statement("adj[0][0] =  spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
5067 			          "m[3][3]);");
5068 			statement("adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
5069 			          "m[3][3]);");
5070 			statement("adj[0][2] =  spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
5071 			          "m[3][3]);");
5072 			statement("adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
5073 			          "m[2][3]);");
5074 			statement_no_indent("");
5075 			statement("adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
5076 			          "m[3][3]);");
5077 			statement("adj[1][1] =  spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
5078 			          "m[3][3]);");
5079 			statement("adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
5080 			          "m[3][3]);");
5081 			statement("adj[1][3] =  spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
5082 			          "m[2][3]);");
5083 			statement_no_indent("");
5084 			statement("adj[2][0] =  spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
5085 			          "m[3][3]);");
5086 			statement("adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
5087 			          "m[3][3]);");
5088 			statement("adj[2][2] =  spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
5089 			          "m[3][3]);");
5090 			statement("adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
5091 			          "m[2][3]);");
5092 			statement_no_indent("");
5093 			statement("adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
5094 			          "m[3][2]);");
5095 			statement("adj[3][1] =  spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
5096 			          "m[3][2]);");
5097 			statement("adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
5098 			          "m[3][2]);");
5099 			statement("adj[3][3] =  spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
5100 			          "m[2][2]);");
5101 			statement_no_indent("");
5102 			statement("// Calculate the determinant as a combination of the cofactors of the first row.");
5103 			statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
5104 			          "* m[3][0]);");
5105 			statement_no_indent("");
5106 			statement("// Divide the classical adjoint matrix by the determinant.");
5107 			statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
5108 			statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
5109 			end_scope();
5110 			statement("");
5111 			break;
5112 
5113 		case SPVFuncImplInverse3x3:
5114 			if (spv_function_implementations.count(SPVFuncImplInverse4x4) == 0)
5115 			{
5116 				statement("// Returns the determinant of a 2x2 matrix.");
5117 				statement(force_inline);
5118 				statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
5119 				begin_scope();
5120 				statement("return a1 * b2 - b1 * a2;");
5121 				end_scope();
5122 				statement("");
5123 			}
5124 
5125 			statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
5126 			statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
5127 			statement(force_inline);
5128 			statement("float3x3 spvInverse3x3(float3x3 m)");
5129 			begin_scope();
5130 			statement("float3x3 adj;	// The adjoint matrix (inverse after dividing by determinant)");
5131 			statement_no_indent("");
5132 			statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
5133 			statement("adj[0][0] =  spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
5134 			statement("adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
5135 			statement("adj[0][2] =  spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
5136 			statement_no_indent("");
5137 			statement("adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
5138 			statement("adj[1][1] =  spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
5139 			statement("adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
5140 			statement_no_indent("");
5141 			statement("adj[2][0] =  spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
5142 			statement("adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
5143 			statement("adj[2][2] =  spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
5144 			statement_no_indent("");
5145 			statement("// Calculate the determinant as a combination of the cofactors of the first row.");
5146 			statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
5147 			statement_no_indent("");
5148 			statement("// Divide the classical adjoint matrix by the determinant.");
5149 			statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
5150 			statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
5151 			end_scope();
5152 			statement("");
5153 			break;
5154 
5155 		case SPVFuncImplInverse2x2:
5156 			statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
5157 			statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
5158 			statement(force_inline);
5159 			statement("float2x2 spvInverse2x2(float2x2 m)");
5160 			begin_scope();
5161 			statement("float2x2 adj;	// The adjoint matrix (inverse after dividing by determinant)");
5162 			statement_no_indent("");
5163 			statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
5164 			statement("adj[0][0] =  m[1][1];");
5165 			statement("adj[0][1] = -m[0][1];");
5166 			statement_no_indent("");
5167 			statement("adj[1][0] = -m[1][0];");
5168 			statement("adj[1][1] =  m[0][0];");
5169 			statement_no_indent("");
5170 			statement("// Calculate the determinant as a combination of the cofactors of the first row.");
5171 			statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
5172 			statement_no_indent("");
5173 			statement("// Divide the classical adjoint matrix by the determinant.");
5174 			statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
5175 			statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
5176 			end_scope();
5177 			statement("");
5178 			break;
5179 
5180 		case SPVFuncImplForwardArgs:
5181 			statement("template<typename T> struct spvRemoveReference { typedef T type; };");
5182 			statement("template<typename T> struct spvRemoveReference<thread T&> { typedef T type; };");
5183 			statement("template<typename T> struct spvRemoveReference<thread T&&> { typedef T type; };");
5184 			statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
5185 			          "spvRemoveReference<T>::type& x)");
5186 			begin_scope();
5187 			statement("return static_cast<thread T&&>(x);");
5188 			end_scope();
5189 			statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
5190 			          "spvRemoveReference<T>::type&& x)");
5191 			begin_scope();
5192 			statement("return static_cast<thread T&&>(x);");
5193 			end_scope();
5194 			statement("");
5195 			break;
5196 
5197 		case SPVFuncImplGetSwizzle:
5198 			statement("enum class spvSwizzle : uint");
5199 			begin_scope();
5200 			statement("none = 0,");
5201 			statement("zero,");
5202 			statement("one,");
5203 			statement("red,");
5204 			statement("green,");
5205 			statement("blue,");
5206 			statement("alpha");
5207 			end_scope_decl();
5208 			statement("");
5209 			statement("template<typename T>");
5210 			statement("inline T spvGetSwizzle(vec<T, 4> x, T c, spvSwizzle s)");
5211 			begin_scope();
5212 			statement("switch (s)");
5213 			begin_scope();
5214 			statement("case spvSwizzle::none:");
5215 			statement("    return c;");
5216 			statement("case spvSwizzle::zero:");
5217 			statement("    return 0;");
5218 			statement("case spvSwizzle::one:");
5219 			statement("    return 1;");
5220 			statement("case spvSwizzle::red:");
5221 			statement("    return x.r;");
5222 			statement("case spvSwizzle::green:");
5223 			statement("    return x.g;");
5224 			statement("case spvSwizzle::blue:");
5225 			statement("    return x.b;");
5226 			statement("case spvSwizzle::alpha:");
5227 			statement("    return x.a;");
5228 			end_scope();
5229 			end_scope();
5230 			statement("");
5231 			break;
5232 
5233 		case SPVFuncImplTextureSwizzle:
5234 			statement("// Wrapper function that swizzles texture samples and fetches.");
5235 			statement("template<typename T>");
5236 			statement("inline vec<T, 4> spvTextureSwizzle(vec<T, 4> x, uint s)");
5237 			begin_scope();
5238 			statement("if (!s)");
5239 			statement("    return x;");
5240 			statement("return vec<T, 4>(spvGetSwizzle(x, x.r, spvSwizzle((s >> 0) & 0xFF)), "
5241 			          "spvGetSwizzle(x, x.g, spvSwizzle((s >> 8) & 0xFF)), spvGetSwizzle(x, x.b, spvSwizzle((s >> 16) "
5242 			          "& 0xFF)), "
5243 			          "spvGetSwizzle(x, x.a, spvSwizzle((s >> 24) & 0xFF)));");
5244 			end_scope();
5245 			statement("");
5246 			statement("template<typename T>");
5247 			statement("inline T spvTextureSwizzle(T x, uint s)");
5248 			begin_scope();
5249 			statement("return spvTextureSwizzle(vec<T, 4>(x, 0, 0, 1), s).x;");
5250 			end_scope();
5251 			statement("");
5252 			break;
5253 
5254 		case SPVFuncImplGatherSwizzle:
5255 			statement("// Wrapper function that swizzles texture gathers.");
5256 			statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
5257 			          "typename... Ts>");
5258 			statement("inline vec<T, 4> spvGatherSwizzle(const thread Tex<T>& t, sampler s, "
5259 			          "uint sw, component c, Ts... params) METAL_CONST_ARG(c)");
5260 			begin_scope();
5261 			statement("if (sw)");
5262 			begin_scope();
5263 			statement("switch (spvSwizzle((sw >> (uint(c) * 8)) & 0xFF))");
5264 			begin_scope();
5265 			statement("case spvSwizzle::none:");
5266 			statement("    break;");
5267 			statement("case spvSwizzle::zero:");
5268 			statement("    return vec<T, 4>(0, 0, 0, 0);");
5269 			statement("case spvSwizzle::one:");
5270 			statement("    return vec<T, 4>(1, 1, 1, 1);");
5271 			statement("case spvSwizzle::red:");
5272 			statement("    return t.gather(s, spvForward<Ts>(params)..., component::x);");
5273 			statement("case spvSwizzle::green:");
5274 			statement("    return t.gather(s, spvForward<Ts>(params)..., component::y);");
5275 			statement("case spvSwizzle::blue:");
5276 			statement("    return t.gather(s, spvForward<Ts>(params)..., component::z);");
5277 			statement("case spvSwizzle::alpha:");
5278 			statement("    return t.gather(s, spvForward<Ts>(params)..., component::w);");
5279 			end_scope();
5280 			end_scope();
5281 			// texture::gather insists on its component parameter being a constant
5282 			// expression, so we need this silly workaround just to compile the shader.
5283 			statement("switch (c)");
5284 			begin_scope();
5285 			statement("case component::x:");
5286 			statement("    return t.gather(s, spvForward<Ts>(params)..., component::x);");
5287 			statement("case component::y:");
5288 			statement("    return t.gather(s, spvForward<Ts>(params)..., component::y);");
5289 			statement("case component::z:");
5290 			statement("    return t.gather(s, spvForward<Ts>(params)..., component::z);");
5291 			statement("case component::w:");
5292 			statement("    return t.gather(s, spvForward<Ts>(params)..., component::w);");
5293 			end_scope();
5294 			end_scope();
5295 			statement("");
5296 			break;
5297 
5298 		case SPVFuncImplGatherCompareSwizzle:
5299 			statement("// Wrapper function that swizzles depth texture gathers.");
5300 			statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
5301 			          "typename... Ts>");
5302 			statement("inline vec<T, 4> spvGatherCompareSwizzle(const thread Tex<T>& t, sampler "
5303 			          "s, uint sw, Ts... params) ");
5304 			begin_scope();
5305 			statement("if (sw)");
5306 			begin_scope();
5307 			statement("switch (spvSwizzle(sw & 0xFF))");
5308 			begin_scope();
5309 			statement("case spvSwizzle::none:");
5310 			statement("case spvSwizzle::red:");
5311 			statement("    break;");
5312 			statement("case spvSwizzle::zero:");
5313 			statement("case spvSwizzle::green:");
5314 			statement("case spvSwizzle::blue:");
5315 			statement("case spvSwizzle::alpha:");
5316 			statement("    return vec<T, 4>(0, 0, 0, 0);");
5317 			statement("case spvSwizzle::one:");
5318 			statement("    return vec<T, 4>(1, 1, 1, 1);");
5319 			end_scope();
5320 			end_scope();
5321 			statement("return t.gather_compare(s, spvForward<Ts>(params)...);");
5322 			end_scope();
5323 			statement("");
5324 			break;
5325 
5326 		case SPVFuncImplSubgroupBroadcast:
5327 			// Metal doesn't allow broadcasting boolean values directly, but we can work around that by broadcasting
5328 			// them as integers.
5329 			statement("template<typename T>");
5330 			statement("inline T spvSubgroupBroadcast(T value, ushort lane)");
5331 			begin_scope();
5332 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5333 				statement("return quad_broadcast(value, lane);");
5334 			else
5335 				statement("return simd_broadcast(value, lane);");
5336 			end_scope();
5337 			statement("");
5338 			statement("template<>");
5339 			statement("inline bool spvSubgroupBroadcast(bool value, ushort lane)");
5340 			begin_scope();
5341 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5342 				statement("return !!quad_broadcast((ushort)value, lane);");
5343 			else
5344 				statement("return !!simd_broadcast((ushort)value, lane);");
5345 			end_scope();
5346 			statement("");
5347 			statement("template<uint N>");
5348 			statement("inline vec<bool, N> spvSubgroupBroadcast(vec<bool, N> value, ushort lane)");
5349 			begin_scope();
5350 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5351 				statement("return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
5352 			else
5353 				statement("return (vec<bool, N>)simd_broadcast((vec<ushort, N>)value, lane);");
5354 			end_scope();
5355 			statement("");
5356 			break;
5357 
5358 		case SPVFuncImplSubgroupBroadcastFirst:
5359 			statement("template<typename T>");
5360 			statement("inline T spvSubgroupBroadcastFirst(T value)");
5361 			begin_scope();
5362 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5363 				statement("return quad_broadcast_first(value);");
5364 			else
5365 				statement("return simd_broadcast_first(value);");
5366 			end_scope();
5367 			statement("");
5368 			statement("template<>");
5369 			statement("inline bool spvSubgroupBroadcastFirst(bool value)");
5370 			begin_scope();
5371 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5372 				statement("return !!quad_broadcast_first((ushort)value);");
5373 			else
5374 				statement("return !!simd_broadcast_first((ushort)value);");
5375 			end_scope();
5376 			statement("");
5377 			statement("template<uint N>");
5378 			statement("inline vec<bool, N> spvSubgroupBroadcastFirst(vec<bool, N> value)");
5379 			begin_scope();
5380 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5381 				statement("return (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value);");
5382 			else
5383 				statement("return (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value);");
5384 			end_scope();
5385 			statement("");
5386 			break;
5387 
5388 		case SPVFuncImplSubgroupBallot:
5389 			statement("inline uint4 spvSubgroupBallot(bool value)");
5390 			begin_scope();
5391 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5392 			{
5393 				statement("return uint4((quad_vote::vote_t)quad_ballot(value), 0, 0, 0);");
5394 			}
5395 			else if (msl_options.is_ios())
5396 			{
5397 				// The current simd_vote on iOS uses a 32-bit integer-like object.
5398 				statement("return uint4((simd_vote::vote_t)simd_ballot(value), 0, 0, 0);");
5399 			}
5400 			else
5401 			{
5402 				statement("simd_vote vote = simd_ballot(value);");
5403 				statement("// simd_ballot() returns a 64-bit integer-like object, but");
5404 				statement("// SPIR-V callers expect a uint4. We must convert.");
5405 				statement("// FIXME: This won't include higher bits if Apple ever supports");
5406 				statement("// 128 lanes in an SIMD-group.");
5407 				statement("return uint4(as_type<uint2>((simd_vote::vote_t)vote), 0, 0);");
5408 			}
5409 			end_scope();
5410 			statement("");
5411 			break;
5412 
5413 		case SPVFuncImplSubgroupBallotBitExtract:
5414 			statement("inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)");
5415 			begin_scope();
5416 			statement("return !!extract_bits(ballot[bit / 32], bit % 32, 1);");
5417 			end_scope();
5418 			statement("");
5419 			break;
5420 
5421 		case SPVFuncImplSubgroupBallotFindLSB:
5422 			statement("inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize)");
5423 			begin_scope();
5424 			if (msl_options.is_ios())
5425 			{
5426 				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
5427 			}
5428 			else
5429 			{
5430 				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
5431 				          "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
5432 			}
5433 			statement("ballot &= mask;");
5434 			statement("return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + "
5435 			          "ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);");
5436 			end_scope();
5437 			statement("");
5438 			break;
5439 
5440 		case SPVFuncImplSubgroupBallotFindMSB:
5441 			statement("inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize)");
5442 			begin_scope();
5443 			if (msl_options.is_ios())
5444 			{
5445 				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
5446 			}
5447 			else
5448 			{
5449 				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
5450 				          "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
5451 			}
5452 			statement("ballot &= mask;");
5453 			statement("return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - "
5454 			          "(clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), "
5455 			          "ballot.z == 0), ballot.w == 0);");
5456 			end_scope();
5457 			statement("");
5458 			break;
5459 
5460 		case SPVFuncImplSubgroupBallotBitCount:
5461 			statement("inline uint spvPopCount4(uint4 ballot)");
5462 			begin_scope();
5463 			statement("return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);");
5464 			end_scope();
5465 			statement("");
5466 			statement("inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize)");
5467 			begin_scope();
5468 			if (msl_options.is_ios())
5469 			{
5470 				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
5471 			}
5472 			else
5473 			{
5474 				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
5475 				          "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
5476 			}
5477 			statement("return spvPopCount4(ballot & mask);");
5478 			end_scope();
5479 			statement("");
5480 			statement("inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
5481 			begin_scope();
5482 			if (msl_options.is_ios())
5483 			{
5484 				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID + 1), uint3(0));");
5485 			}
5486 			else
5487 			{
5488 				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), "
5489 				          "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), "
5490 				          "uint2(0));");
5491 			}
5492 			statement("return spvPopCount4(ballot & mask);");
5493 			end_scope();
5494 			statement("");
5495 			statement("inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
5496 			begin_scope();
5497 			if (msl_options.is_ios())
5498 			{
5499 				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID), uint2(0));");
5500 			}
5501 			else
5502 			{
5503 				statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), "
5504 				          "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));");
5505 			}
5506 			statement("return spvPopCount4(ballot & mask);");
5507 			end_scope();
5508 			statement("");
5509 			break;
5510 
5511 		case SPVFuncImplSubgroupAllEqual:
5512 			// Metal doesn't provide a function to evaluate this directly. But, we can
5513 			// implement this by comparing every thread's value to one thread's value
5514 			// (in this case, the value of the first active thread). Then, by the transitive
5515 			// property of equality, if all comparisons return true, then they are all equal.
5516 			statement("template<typename T>");
5517 			statement("inline bool spvSubgroupAllEqual(T value)");
5518 			begin_scope();
5519 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5520 				statement("return quad_all(all(value == quad_broadcast_first(value)));");
5521 			else
5522 				statement("return simd_all(all(value == simd_broadcast_first(value)));");
5523 			end_scope();
5524 			statement("");
5525 			statement("template<>");
5526 			statement("inline bool spvSubgroupAllEqual(bool value)");
5527 			begin_scope();
5528 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5529 				statement("return quad_all(value) || !quad_any(value);");
5530 			else
5531 				statement("return simd_all(value) || !simd_any(value);");
5532 			end_scope();
5533 			statement("");
5534 			statement("template<uint N>");
5535 			statement("inline bool spvSubgroupAllEqual(vec<bool, N> value)");
5536 			begin_scope();
5537 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5538 				statement("return quad_all(all(value == (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value)));");
5539 			else
5540 				statement("return simd_all(all(value == (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value)));");
5541 			end_scope();
5542 			statement("");
5543 			break;
5544 
5545 		case SPVFuncImplSubgroupShuffle:
5546 			statement("template<typename T>");
5547 			statement("inline T spvSubgroupShuffle(T value, ushort lane)");
5548 			begin_scope();
5549 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5550 				statement("return quad_shuffle(value, lane);");
5551 			else
5552 				statement("return simd_shuffle(value, lane);");
5553 			end_scope();
5554 			statement("");
5555 			statement("template<>");
5556 			statement("inline bool spvSubgroupShuffle(bool value, ushort lane)");
5557 			begin_scope();
5558 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5559 				statement("return !!quad_shuffle((ushort)value, lane);");
5560 			else
5561 				statement("return !!simd_shuffle((ushort)value, lane);");
5562 			end_scope();
5563 			statement("");
5564 			statement("template<uint N>");
5565 			statement("inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)");
5566 			begin_scope();
5567 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5568 				statement("return (vec<bool, N>)quad_shuffle((vec<ushort, N>)value, lane);");
5569 			else
5570 				statement("return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);");
5571 			end_scope();
5572 			statement("");
5573 			break;
5574 
5575 		case SPVFuncImplSubgroupShuffleXor:
5576 			statement("template<typename T>");
5577 			statement("inline T spvSubgroupShuffleXor(T value, ushort mask)");
5578 			begin_scope();
5579 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5580 				statement("return quad_shuffle_xor(value, mask);");
5581 			else
5582 				statement("return simd_shuffle_xor(value, mask);");
5583 			end_scope();
5584 			statement("");
5585 			statement("template<>");
5586 			statement("inline bool spvSubgroupShuffleXor(bool value, ushort mask)");
5587 			begin_scope();
5588 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5589 				statement("return !!quad_shuffle_xor((ushort)value, mask);");
5590 			else
5591 				statement("return !!simd_shuffle_xor((ushort)value, mask);");
5592 			end_scope();
5593 			statement("");
5594 			statement("template<uint N>");
5595 			statement("inline vec<bool, N> spvSubgroupShuffleXor(vec<bool, N> value, ushort mask)");
5596 			begin_scope();
5597 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5598 				statement("return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, mask);");
5599 			else
5600 				statement("return (vec<bool, N>)simd_shuffle_xor((vec<ushort, N>)value, mask);");
5601 			end_scope();
5602 			statement("");
5603 			break;
5604 
5605 		case SPVFuncImplSubgroupShuffleUp:
5606 			statement("template<typename T>");
5607 			statement("inline T spvSubgroupShuffleUp(T value, ushort delta)");
5608 			begin_scope();
5609 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5610 				statement("return quad_shuffle_up(value, delta);");
5611 			else
5612 				statement("return simd_shuffle_up(value, delta);");
5613 			end_scope();
5614 			statement("");
5615 			statement("template<>");
5616 			statement("inline bool spvSubgroupShuffleUp(bool value, ushort delta)");
5617 			begin_scope();
5618 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5619 				statement("return !!quad_shuffle_up((ushort)value, delta);");
5620 			else
5621 				statement("return !!simd_shuffle_up((ushort)value, delta);");
5622 			end_scope();
5623 			statement("");
5624 			statement("template<uint N>");
5625 			statement("inline vec<bool, N> spvSubgroupShuffleUp(vec<bool, N> value, ushort delta)");
5626 			begin_scope();
5627 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5628 				statement("return (vec<bool, N>)quad_shuffle_up((vec<ushort, N>)value, delta);");
5629 			else
5630 				statement("return (vec<bool, N>)simd_shuffle_up((vec<ushort, N>)value, delta);");
5631 			end_scope();
5632 			statement("");
5633 			break;
5634 
5635 		case SPVFuncImplSubgroupShuffleDown:
5636 			statement("template<typename T>");
5637 			statement("inline T spvSubgroupShuffleDown(T value, ushort delta)");
5638 			begin_scope();
5639 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5640 				statement("return quad_shuffle_down(value, delta);");
5641 			else
5642 				statement("return simd_shuffle_down(value, delta);");
5643 			end_scope();
5644 			statement("");
5645 			statement("template<>");
5646 			statement("inline bool spvSubgroupShuffleDown(bool value, ushort delta)");
5647 			begin_scope();
5648 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5649 				statement("return !!quad_shuffle_down((ushort)value, delta);");
5650 			else
5651 				statement("return !!simd_shuffle_down((ushort)value, delta);");
5652 			end_scope();
5653 			statement("");
5654 			statement("template<uint N>");
5655 			statement("inline vec<bool, N> spvSubgroupShuffleDown(vec<bool, N> value, ushort delta)");
5656 			begin_scope();
5657 			if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5658 				statement("return (vec<bool, N>)quad_shuffle_down((vec<ushort, N>)value, delta);");
5659 			else
5660 				statement("return (vec<bool, N>)simd_shuffle_down((vec<ushort, N>)value, delta);");
5661 			end_scope();
5662 			statement("");
5663 			break;
5664 
5665 		case SPVFuncImplQuadBroadcast:
5666 			statement("template<typename T>");
5667 			statement("inline T spvQuadBroadcast(T value, uint lane)");
5668 			begin_scope();
5669 			statement("return quad_broadcast(value, lane);");
5670 			end_scope();
5671 			statement("");
5672 			statement("template<>");
5673 			statement("inline bool spvQuadBroadcast(bool value, uint lane)");
5674 			begin_scope();
5675 			statement("return !!quad_broadcast((ushort)value, lane);");
5676 			end_scope();
5677 			statement("");
5678 			statement("template<uint N>");
5679 			statement("inline vec<bool, N> spvQuadBroadcast(vec<bool, N> value, uint lane)");
5680 			begin_scope();
5681 			statement("return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
5682 			end_scope();
5683 			statement("");
5684 			break;
5685 
5686 		case SPVFuncImplQuadSwap:
5687 			// We can implement this easily based on the following table giving
5688 			// the target lane ID from the direction and current lane ID:
5689 			//        Direction
5690 			//      | 0 | 1 | 2 |
5691 			//   ---+---+---+---+
5692 			// L 0  | 1   2   3
5693 			// a 1  | 0   3   2
5694 			// n 2  | 3   0   1
5695 			// e 3  | 2   1   0
5696 			// Notice that target = source ^ (direction + 1).
5697 			statement("template<typename T>");
5698 			statement("inline T spvQuadSwap(T value, uint dir)");
5699 			begin_scope();
5700 			statement("return quad_shuffle_xor(value, dir + 1);");
5701 			end_scope();
5702 			statement("");
5703 			statement("template<>");
5704 			statement("inline bool spvQuadSwap(bool value, uint dir)");
5705 			begin_scope();
5706 			statement("return !!quad_shuffle_xor((ushort)value, dir + 1);");
5707 			end_scope();
5708 			statement("");
5709 			statement("template<uint N>");
5710 			statement("inline vec<bool, N> spvQuadSwap(vec<bool, N> value, uint dir)");
5711 			begin_scope();
5712 			statement("return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, dir + 1);");
5713 			end_scope();
5714 			statement("");
5715 			break;
5716 
5717 		case SPVFuncImplReflectScalar:
5718 			// Metal does not support scalar versions of these functions.
5719 			statement("template<typename T>");
5720 			statement("inline T spvReflect(T i, T n)");
5721 			begin_scope();
5722 			statement("return i - T(2) * i * n * n;");
5723 			end_scope();
5724 			statement("");
5725 			break;
5726 
5727 		case SPVFuncImplRefractScalar:
5728 			// Metal does not support scalar versions of these functions.
5729 			statement("template<typename T>");
5730 			statement("inline T spvRefract(T i, T n, T eta)");
5731 			begin_scope();
5732 			statement("T NoI = n * i;");
5733 			statement("T NoI2 = NoI * NoI;");
5734 			statement("T k = T(1) - eta * eta * (T(1) - NoI2);");
5735 			statement("if (k < T(0))");
5736 			begin_scope();
5737 			statement("return T(0);");
5738 			end_scope();
5739 			statement("else");
5740 			begin_scope();
5741 			statement("return eta * i - (eta * NoI + sqrt(k)) * n;");
5742 			end_scope();
5743 			end_scope();
5744 			statement("");
5745 			break;
5746 
5747 		case SPVFuncImplFaceForwardScalar:
5748 			// Metal does not support scalar versions of these functions.
5749 			statement("template<typename T>");
5750 			statement("inline T spvFaceForward(T n, T i, T nref)");
5751 			begin_scope();
5752 			statement("return i * nref < T(0) ? n : -n;");
5753 			end_scope();
5754 			statement("");
5755 			break;
5756 
5757 		case SPVFuncImplChromaReconstructNearest2Plane:
5758 			statement("template<typename T, typename... LodOptions>");
5759 			statement("inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, sampler "
5760 			          "samp, float2 coord, LodOptions... options)");
5761 			begin_scope();
5762 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5763 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5764 			statement("ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
5765 			statement("return ycbcr;");
5766 			end_scope();
5767 			statement("");
5768 			break;
5769 
5770 		case SPVFuncImplChromaReconstructNearest3Plane:
5771 			statement("template<typename T, typename... LodOptions>");
5772 			statement("inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, "
5773 			          "texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5774 			begin_scope();
5775 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5776 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5777 			statement("ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5778 			statement("ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5779 			statement("return ycbcr;");
5780 			end_scope();
5781 			statement("");
5782 			break;
5783 
5784 		case SPVFuncImplChromaReconstructLinear422CositedEven2Plane:
5785 			statement("template<typename T, typename... LodOptions>");
5786 			statement("inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
5787 			          "plane1, sampler samp, float2 coord, LodOptions... options)");
5788 			begin_scope();
5789 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5790 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5791 			statement("if (fract(coord.x * plane1.get_width()) != 0.0)");
5792 			begin_scope();
5793 			statement("ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5794 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).rg);");
5795 			end_scope();
5796 			statement("else");
5797 			begin_scope();
5798 			statement("ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
5799 			end_scope();
5800 			statement("return ycbcr;");
5801 			end_scope();
5802 			statement("");
5803 			break;
5804 
5805 		case SPVFuncImplChromaReconstructLinear422CositedEven3Plane:
5806 			statement("template<typename T, typename... LodOptions>");
5807 			statement("inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
5808 			          "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5809 			begin_scope();
5810 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5811 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5812 			statement("if (fract(coord.x * plane1.get_width()) != 0.0)");
5813 			begin_scope();
5814 			statement("ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5815 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
5816 			statement("ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5817 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
5818 			end_scope();
5819 			statement("else");
5820 			begin_scope();
5821 			statement("ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5822 			statement("ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5823 			end_scope();
5824 			statement("return ycbcr;");
5825 			end_scope();
5826 			statement("");
5827 			break;
5828 
5829 		case SPVFuncImplChromaReconstructLinear422Midpoint2Plane:
5830 			statement("template<typename T, typename... LodOptions>");
5831 			statement("inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
5832 			          "plane1, sampler samp, float2 coord, LodOptions... options)");
5833 			begin_scope();
5834 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5835 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5836 			statement("int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
5837 			statement("ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5838 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).rg);");
5839 			statement("return ycbcr;");
5840 			end_scope();
5841 			statement("");
5842 			break;
5843 
5844 		case SPVFuncImplChromaReconstructLinear422Midpoint3Plane:
5845 			statement("template<typename T, typename... LodOptions>");
5846 			statement("inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
5847 			          "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5848 			begin_scope();
5849 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5850 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5851 			statement("int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
5852 			statement("ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5853 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
5854 			statement("ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5855 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
5856 			statement("return ycbcr;");
5857 			end_scope();
5858 			statement("");
5859 			break;
5860 
5861 		case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane:
5862 			statement("template<typename T, typename... LodOptions>");
5863 			statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
5864 			          "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
5865 			begin_scope();
5866 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5867 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5868 			statement("float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
5869 			statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5870 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5871 			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5872 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
5873 			statement("return ycbcr;");
5874 			end_scope();
5875 			statement("");
5876 			break;
5877 
5878 		case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane:
5879 			statement("template<typename T, typename... LodOptions>");
5880 			statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
5881 			          "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5882 			begin_scope();
5883 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5884 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5885 			statement("float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
5886 			statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5887 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5888 			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5889 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5890 			statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5891 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5892 			          "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5893 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5894 			statement("return ycbcr;");
5895 			end_scope();
5896 			statement("");
5897 			break;
5898 
5899 		case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane:
5900 			statement("template<typename T, typename... LodOptions>");
5901 			statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
5902 			          "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
5903 			begin_scope();
5904 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5905 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5906 			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
5907 			          "0)) * 0.5);");
5908 			statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5909 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5910 			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5911 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
5912 			statement("return ycbcr;");
5913 			end_scope();
5914 			statement("");
5915 			break;
5916 
5917 		case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane:
5918 			statement("template<typename T, typename... LodOptions>");
5919 			statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
5920 			          "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5921 			begin_scope();
5922 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5923 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5924 			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
5925 			          "0)) * 0.5);");
5926 			statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5927 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5928 			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5929 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5930 			statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5931 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5932 			          "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5933 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5934 			statement("return ycbcr;");
5935 			end_scope();
5936 			statement("");
5937 			break;
5938 
5939 		case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane:
5940 			statement("template<typename T, typename... LodOptions>");
5941 			statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
5942 			          "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
5943 			begin_scope();
5944 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5945 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5946 			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
5947 			          "0.5)) * 0.5);");
5948 			statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5949 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5950 			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5951 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
5952 			statement("return ycbcr;");
5953 			end_scope();
5954 			statement("");
5955 			break;
5956 
5957 		case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane:
5958 			statement("template<typename T, typename... LodOptions>");
5959 			statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
5960 			          "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5961 			begin_scope();
5962 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5963 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5964 			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
5965 			          "0.5)) * 0.5);");
5966 			statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5967 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5968 			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5969 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5970 			statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5971 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5972 			          "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5973 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5974 			statement("return ycbcr;");
5975 			end_scope();
5976 			statement("");
5977 			break;
5978 
5979 		case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane:
5980 			statement("template<typename T, typename... LodOptions>");
5981 			statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
5982 			          "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
5983 			begin_scope();
5984 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5985 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5986 			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
5987 			          "0.5)) * 0.5);");
5988 			statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5989 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5990 			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5991 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
5992 			statement("return ycbcr;");
5993 			end_scope();
5994 			statement("");
5995 			break;
5996 
5997 		case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane:
5998 			statement("template<typename T, typename... LodOptions>");
5999 			statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
6000 			          "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
6001 			begin_scope();
6002 			statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6003 			statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6004 			statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
6005 			          "0.5)) * 0.5);");
6006 			statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
6007 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
6008 			          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
6009 			          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
6010 			statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
6011 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
6012 			          "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
6013 			          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
6014 			statement("return ycbcr;");
6015 			end_scope();
6016 			statement("");
6017 			break;
6018 
6019 		case SPVFuncImplExpandITUFullRange:
6020 			statement("template<typename T>");
6021 			statement("inline vec<T, 4> spvExpandITUFullRange(vec<T, 4> ycbcr, int n)");
6022 			begin_scope();
6023 			statement("ycbcr.br -= exp2(T(n-1))/(exp2(T(n))-1);");
6024 			statement("return ycbcr;");
6025 			end_scope();
6026 			statement("");
6027 			break;
6028 
6029 		case SPVFuncImplExpandITUNarrowRange:
6030 			statement("template<typename T>");
6031 			statement("inline vec<T, 4> spvExpandITUNarrowRange(vec<T, 4> ycbcr, int n)");
6032 			begin_scope();
6033 			statement("ycbcr.g = (ycbcr.g * (exp2(T(n)) - 1) - ldexp(T(16), n - 8))/ldexp(T(219), n - 8);");
6034 			statement("ycbcr.br = (ycbcr.br * (exp2(T(n)) - 1) - ldexp(T(128), n - 8))/ldexp(T(224), n - 8);");
6035 			statement("return ycbcr;");
6036 			end_scope();
6037 			statement("");
6038 			break;
6039 
6040 		case SPVFuncImplConvertYCbCrBT709:
6041 			statement("// cf. Khronos Data Format Specification, section 15.1.1");
6042 			statement("constant float3x3 spvBT709Factors = {{1, 1, 1}, {0, -0.13397432/0.7152, 1.8556}, {1.5748, "
6043 			          "-0.33480248/0.7152, 0}};");
6044 			statement("");
6045 			statement("template<typename T>");
6046 			statement("inline vec<T, 4> spvConvertYCbCrBT709(vec<T, 4> ycbcr)");
6047 			begin_scope();
6048 			statement("vec<T, 4> rgba;");
6049 			statement("rgba.rgb = vec<T, 3>(spvBT709Factors * ycbcr.gbr);");
6050 			statement("rgba.a = ycbcr.a;");
6051 			statement("return rgba;");
6052 			end_scope();
6053 			statement("");
6054 			break;
6055 
6056 		case SPVFuncImplConvertYCbCrBT601:
6057 			statement("// cf. Khronos Data Format Specification, section 15.1.2");
6058 			statement("constant float3x3 spvBT601Factors = {{1, 1, 1}, {0, -0.202008/0.587, 1.772}, {1.402, "
6059 			          "-0.419198/0.587, 0}};");
6060 			statement("");
6061 			statement("template<typename T>");
6062 			statement("inline vec<T, 4> spvConvertYCbCrBT601(vec<T, 4> ycbcr)");
6063 			begin_scope();
6064 			statement("vec<T, 4> rgba;");
6065 			statement("rgba.rgb = vec<T, 3>(spvBT601Factors * ycbcr.gbr);");
6066 			statement("rgba.a = ycbcr.a;");
6067 			statement("return rgba;");
6068 			end_scope();
6069 			statement("");
6070 			break;
6071 
6072 		case SPVFuncImplConvertYCbCrBT2020:
6073 			statement("// cf. Khronos Data Format Specification, section 15.1.3");
6074 			statement("constant float3x3 spvBT2020Factors = {{1, 1, 1}, {0, -0.11156702/0.6780, 1.8814}, {1.4746, "
6075 			          "-0.38737742/0.6780, 0}};");
6076 			statement("");
6077 			statement("template<typename T>");
6078 			statement("inline vec<T, 4> spvConvertYCbCrBT2020(vec<T, 4> ycbcr)");
6079 			begin_scope();
6080 			statement("vec<T, 4> rgba;");
6081 			statement("rgba.rgb = vec<T, 3>(spvBT2020Factors * ycbcr.gbr);");
6082 			statement("rgba.a = ycbcr.a;");
6083 			statement("return rgba;");
6084 			end_scope();
6085 			statement("");
6086 			break;
6087 
6088 		case SPVFuncImplDynamicImageSampler:
6089 			statement("enum class spvFormatResolution");
6090 			begin_scope();
6091 			statement("_444 = 0,");
6092 			statement("_422,");
6093 			statement("_420");
6094 			end_scope_decl();
6095 			statement("");
6096 			statement("enum class spvChromaFilter");
6097 			begin_scope();
6098 			statement("nearest = 0,");
6099 			statement("linear");
6100 			end_scope_decl();
6101 			statement("");
6102 			statement("enum class spvXChromaLocation");
6103 			begin_scope();
6104 			statement("cosited_even = 0,");
6105 			statement("midpoint");
6106 			end_scope_decl();
6107 			statement("");
6108 			statement("enum class spvYChromaLocation");
6109 			begin_scope();
6110 			statement("cosited_even = 0,");
6111 			statement("midpoint");
6112 			end_scope_decl();
6113 			statement("");
6114 			statement("enum class spvYCbCrModelConversion");
6115 			begin_scope();
6116 			statement("rgb_identity = 0,");
6117 			statement("ycbcr_identity,");
6118 			statement("ycbcr_bt_709,");
6119 			statement("ycbcr_bt_601,");
6120 			statement("ycbcr_bt_2020");
6121 			end_scope_decl();
6122 			statement("");
6123 			statement("enum class spvYCbCrRange");
6124 			begin_scope();
6125 			statement("itu_full = 0,");
6126 			statement("itu_narrow");
6127 			end_scope_decl();
6128 			statement("");
6129 			statement("struct spvComponentBits");
6130 			begin_scope();
6131 			statement("constexpr explicit spvComponentBits(int v) thread : value(v) {}");
6132 			statement("uchar value : 6;");
6133 			end_scope_decl();
6134 			statement("// A class corresponding to metal::sampler which holds sampler");
6135 			statement("// Y'CbCr conversion info.");
6136 			statement("struct spvYCbCrSampler");
6137 			begin_scope();
6138 			statement("constexpr spvYCbCrSampler() thread : val(build()) {}");
6139 			statement("template<typename... Ts>");
6140 			statement("constexpr spvYCbCrSampler(Ts... t) thread : val(build(t...)) {}");
6141 			statement("constexpr spvYCbCrSampler(const thread spvYCbCrSampler& s) thread = default;");
6142 			statement("");
6143 			statement("spvFormatResolution get_resolution() const thread");
6144 			begin_scope();
6145 			statement("return spvFormatResolution((val & resolution_mask) >> resolution_base);");
6146 			end_scope();
6147 			statement("spvChromaFilter get_chroma_filter() const thread");
6148 			begin_scope();
6149 			statement("return spvChromaFilter((val & chroma_filter_mask) >> chroma_filter_base);");
6150 			end_scope();
6151 			statement("spvXChromaLocation get_x_chroma_offset() const thread");
6152 			begin_scope();
6153 			statement("return spvXChromaLocation((val & x_chroma_off_mask) >> x_chroma_off_base);");
6154 			end_scope();
6155 			statement("spvYChromaLocation get_y_chroma_offset() const thread");
6156 			begin_scope();
6157 			statement("return spvYChromaLocation((val & y_chroma_off_mask) >> y_chroma_off_base);");
6158 			end_scope();
6159 			statement("spvYCbCrModelConversion get_ycbcr_model() const thread");
6160 			begin_scope();
6161 			statement("return spvYCbCrModelConversion((val & ycbcr_model_mask) >> ycbcr_model_base);");
6162 			end_scope();
6163 			statement("spvYCbCrRange get_ycbcr_range() const thread");
6164 			begin_scope();
6165 			statement("return spvYCbCrRange((val & ycbcr_range_mask) >> ycbcr_range_base);");
6166 			end_scope();
6167 			statement("int get_bpc() const thread { return (val & bpc_mask) >> bpc_base; }");
6168 			statement("");
6169 			statement("private:");
6170 			statement("ushort val;");
6171 			statement("");
6172 			statement("constexpr static constant ushort resolution_bits = 2;");
6173 			statement("constexpr static constant ushort chroma_filter_bits = 2;");
6174 			statement("constexpr static constant ushort x_chroma_off_bit = 1;");
6175 			statement("constexpr static constant ushort y_chroma_off_bit = 1;");
6176 			statement("constexpr static constant ushort ycbcr_model_bits = 3;");
6177 			statement("constexpr static constant ushort ycbcr_range_bit = 1;");
6178 			statement("constexpr static constant ushort bpc_bits = 6;");
6179 			statement("");
6180 			statement("constexpr static constant ushort resolution_base = 0;");
6181 			statement("constexpr static constant ushort chroma_filter_base = 2;");
6182 			statement("constexpr static constant ushort x_chroma_off_base = 4;");
6183 			statement("constexpr static constant ushort y_chroma_off_base = 5;");
6184 			statement("constexpr static constant ushort ycbcr_model_base = 6;");
6185 			statement("constexpr static constant ushort ycbcr_range_base = 9;");
6186 			statement("constexpr static constant ushort bpc_base = 10;");
6187 			statement("");
6188 			statement(
6189 			    "constexpr static constant ushort resolution_mask = ((1 << resolution_bits) - 1) << resolution_base;");
6190 			statement("constexpr static constant ushort chroma_filter_mask = ((1 << chroma_filter_bits) - 1) << "
6191 			          "chroma_filter_base;");
6192 			statement("constexpr static constant ushort x_chroma_off_mask = ((1 << x_chroma_off_bit) - 1) << "
6193 			          "x_chroma_off_base;");
6194 			statement("constexpr static constant ushort y_chroma_off_mask = ((1 << y_chroma_off_bit) - 1) << "
6195 			          "y_chroma_off_base;");
6196 			statement("constexpr static constant ushort ycbcr_model_mask = ((1 << ycbcr_model_bits) - 1) << "
6197 			          "ycbcr_model_base;");
6198 			statement("constexpr static constant ushort ycbcr_range_mask = ((1 << ycbcr_range_bit) - 1) << "
6199 			          "ycbcr_range_base;");
6200 			statement("constexpr static constant ushort bpc_mask = ((1 << bpc_bits) - 1) << bpc_base;");
6201 			statement("");
6202 			statement("static constexpr ushort build()");
6203 			begin_scope();
6204 			statement("return 0;");
6205 			end_scope();
6206 			statement("");
6207 			statement("template<typename... Ts>");
6208 			statement("static constexpr ushort build(spvFormatResolution res, Ts... t)");
6209 			begin_scope();
6210 			statement("return (ushort(res) << resolution_base) | (build(t...) & ~resolution_mask);");
6211 			end_scope();
6212 			statement("");
6213 			statement("template<typename... Ts>");
6214 			statement("static constexpr ushort build(spvChromaFilter filt, Ts... t)");
6215 			begin_scope();
6216 			statement("return (ushort(filt) << chroma_filter_base) | (build(t...) & ~chroma_filter_mask);");
6217 			end_scope();
6218 			statement("");
6219 			statement("template<typename... Ts>");
6220 			statement("static constexpr ushort build(spvXChromaLocation loc, Ts... t)");
6221 			begin_scope();
6222 			statement("return (ushort(loc) << x_chroma_off_base) | (build(t...) & ~x_chroma_off_mask);");
6223 			end_scope();
6224 			statement("");
6225 			statement("template<typename... Ts>");
6226 			statement("static constexpr ushort build(spvYChromaLocation loc, Ts... t)");
6227 			begin_scope();
6228 			statement("return (ushort(loc) << y_chroma_off_base) | (build(t...) & ~y_chroma_off_mask);");
6229 			end_scope();
6230 			statement("");
6231 			statement("template<typename... Ts>");
6232 			statement("static constexpr ushort build(spvYCbCrModelConversion model, Ts... t)");
6233 			begin_scope();
6234 			statement("return (ushort(model) << ycbcr_model_base) | (build(t...) & ~ycbcr_model_mask);");
6235 			end_scope();
6236 			statement("");
6237 			statement("template<typename... Ts>");
6238 			statement("static constexpr ushort build(spvYCbCrRange range, Ts... t)");
6239 			begin_scope();
6240 			statement("return (ushort(range) << ycbcr_range_base) | (build(t...) & ~ycbcr_range_mask);");
6241 			end_scope();
6242 			statement("");
6243 			statement("template<typename... Ts>");
6244 			statement("static constexpr ushort build(spvComponentBits bpc, Ts... t)");
6245 			begin_scope();
6246 			statement("return (ushort(bpc.value) << bpc_base) | (build(t...) & ~bpc_mask);");
6247 			end_scope();
6248 			end_scope_decl();
6249 			statement("");
6250 			statement("// A class which can hold up to three textures and a sampler, including");
6251 			statement("// Y'CbCr conversion info, used to pass combined image-samplers");
6252 			statement("// dynamically to functions.");
6253 			statement("template<typename T>");
6254 			statement("struct spvDynamicImageSampler");
6255 			begin_scope();
6256 			statement("texture2d<T> plane0;");
6257 			statement("texture2d<T> plane1;");
6258 			statement("texture2d<T> plane2;");
6259 			statement("sampler samp;");
6260 			statement("spvYCbCrSampler ycbcr_samp;");
6261 			statement("uint swizzle = 0;");
6262 			statement("");
6263 			if (msl_options.swizzle_texture_samples)
6264 			{
6265 				statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, uint sw) thread :");
6266 				statement("    plane0(tex), samp(samp), swizzle(sw) {}");
6267 			}
6268 			else
6269 			{
6270 				statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp) thread :");
6271 				statement("    plane0(tex), samp(samp) {}");
6272 			}
6273 			statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, spvYCbCrSampler ycbcr_samp, "
6274 			          "uint sw) thread :");
6275 			statement("    plane0(tex), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
6276 			statement("constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1,");
6277 			statement("                                 sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
6278 			statement("    plane0(plane0), plane1(plane1), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
6279 			statement(
6280 			    "constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1, texture2d<T> plane2,");
6281 			statement("                                 sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
6282 			statement("    plane0(plane0), plane1(plane1), plane2(plane2), samp(samp), ycbcr_samp(ycbcr_samp), "
6283 			          "swizzle(sw) {}");
6284 			statement("");
6285 			// XXX This is really hard to follow... I've left comments to make it a bit easier.
6286 			statement("template<typename... LodOptions>");
6287 			statement("vec<T, 4> do_sample(float2 coord, LodOptions... options) const thread");
6288 			begin_scope();
6289 			statement("if (!is_null_texture(plane1))");
6290 			begin_scope();
6291 			statement("if (ycbcr_samp.get_resolution() == spvFormatResolution::_444 ||");
6292 			statement("    ycbcr_samp.get_chroma_filter() == spvChromaFilter::nearest)");
6293 			begin_scope();
6294 			statement("if (!is_null_texture(plane2))");
6295 			statement("    return spvChromaReconstructNearest(plane0, plane1, plane2, samp, coord,");
6296 			statement("                                       spvForward<LodOptions>(options)...);");
6297 			statement(
6298 			    "return spvChromaReconstructNearest(plane0, plane1, samp, coord, spvForward<LodOptions>(options)...);");
6299 			end_scope(); // if (resolution == 444 || chroma_filter == nearest)
6300 			statement("switch (ycbcr_samp.get_resolution())");
6301 			begin_scope();
6302 			statement("case spvFormatResolution::_444: break;");
6303 			statement("case spvFormatResolution::_422:");
6304 			begin_scope();
6305 			statement("switch (ycbcr_samp.get_x_chroma_offset())");
6306 			begin_scope();
6307 			statement("case spvXChromaLocation::cosited_even:");
6308 			statement("    if (!is_null_texture(plane2))");
6309 			statement("        return spvChromaReconstructLinear422CositedEven(");
6310 			statement("            plane0, plane1, plane2, samp,");
6311 			statement("            coord, spvForward<LodOptions>(options)...);");
6312 			statement("    return spvChromaReconstructLinear422CositedEven(");
6313 			statement("        plane0, plane1, samp, coord,");
6314 			statement("        spvForward<LodOptions>(options)...);");
6315 			statement("case spvXChromaLocation::midpoint:");
6316 			statement("    if (!is_null_texture(plane2))");
6317 			statement("        return spvChromaReconstructLinear422Midpoint(");
6318 			statement("            plane0, plane1, plane2, samp,");
6319 			statement("            coord, spvForward<LodOptions>(options)...);");
6320 			statement("    return spvChromaReconstructLinear422Midpoint(");
6321 			statement("        plane0, plane1, samp, coord,");
6322 			statement("        spvForward<LodOptions>(options)...);");
6323 			end_scope(); // switch (x_chroma_offset)
6324 			end_scope(); // case 422:
6325 			statement("case spvFormatResolution::_420:");
6326 			begin_scope();
6327 			statement("switch (ycbcr_samp.get_x_chroma_offset())");
6328 			begin_scope();
6329 			statement("case spvXChromaLocation::cosited_even:");
6330 			begin_scope();
6331 			statement("switch (ycbcr_samp.get_y_chroma_offset())");
6332 			begin_scope();
6333 			statement("case spvYChromaLocation::cosited_even:");
6334 			statement("    if (!is_null_texture(plane2))");
6335 			statement("        return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
6336 			statement("            plane0, plane1, plane2, samp,");
6337 			statement("            coord, spvForward<LodOptions>(options)...);");
6338 			statement("    return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
6339 			statement("        plane0, plane1, samp, coord,");
6340 			statement("        spvForward<LodOptions>(options)...);");
6341 			statement("case spvYChromaLocation::midpoint:");
6342 			statement("    if (!is_null_texture(plane2))");
6343 			statement("        return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
6344 			statement("            plane0, plane1, plane2, samp,");
6345 			statement("            coord, spvForward<LodOptions>(options)...);");
6346 			statement("    return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
6347 			statement("        plane0, plane1, samp, coord,");
6348 			statement("        spvForward<LodOptions>(options)...);");
6349 			end_scope(); // switch (y_chroma_offset)
6350 			end_scope(); // case x::cosited_even:
6351 			statement("case spvXChromaLocation::midpoint:");
6352 			begin_scope();
6353 			statement("switch (ycbcr_samp.get_y_chroma_offset())");
6354 			begin_scope();
6355 			statement("case spvYChromaLocation::cosited_even:");
6356 			statement("    if (!is_null_texture(plane2))");
6357 			statement("        return spvChromaReconstructLinear420XMidpointYCositedEven(");
6358 			statement("            plane0, plane1, plane2, samp,");
6359 			statement("            coord, spvForward<LodOptions>(options)...);");
6360 			statement("    return spvChromaReconstructLinear420XMidpointYCositedEven(");
6361 			statement("        plane0, plane1, samp, coord,");
6362 			statement("        spvForward<LodOptions>(options)...);");
6363 			statement("case spvYChromaLocation::midpoint:");
6364 			statement("    if (!is_null_texture(plane2))");
6365 			statement("        return spvChromaReconstructLinear420XMidpointYMidpoint(");
6366 			statement("            plane0, plane1, plane2, samp,");
6367 			statement("            coord, spvForward<LodOptions>(options)...);");
6368 			statement("    return spvChromaReconstructLinear420XMidpointYMidpoint(");
6369 			statement("        plane0, plane1, samp, coord,");
6370 			statement("        spvForward<LodOptions>(options)...);");
6371 			end_scope(); // switch (y_chroma_offset)
6372 			end_scope(); // case x::midpoint
6373 			end_scope(); // switch (x_chroma_offset)
6374 			end_scope(); // case 420:
6375 			end_scope(); // switch (resolution)
6376 			end_scope(); // if (multiplanar)
6377 			statement("return plane0.sample(samp, coord, spvForward<LodOptions>(options)...);");
6378 			end_scope(); // do_sample()
6379 			statement("template <typename... LodOptions>");
6380 			statement("vec<T, 4> sample(float2 coord, LodOptions... options) const thread");
6381 			begin_scope();
6382 			statement(
6383 			    "vec<T, 4> s = spvTextureSwizzle(do_sample(coord, spvForward<LodOptions>(options)...), swizzle);");
6384 			statement("if (ycbcr_samp.get_ycbcr_model() == spvYCbCrModelConversion::rgb_identity)");
6385 			statement("    return s;");
6386 			statement("");
6387 			statement("switch (ycbcr_samp.get_ycbcr_range())");
6388 			begin_scope();
6389 			statement("case spvYCbCrRange::itu_full:");
6390 			statement("    s = spvExpandITUFullRange(s, ycbcr_samp.get_bpc());");
6391 			statement("    break;");
6392 			statement("case spvYCbCrRange::itu_narrow:");
6393 			statement("    s = spvExpandITUNarrowRange(s, ycbcr_samp.get_bpc());");
6394 			statement("    break;");
6395 			end_scope();
6396 			statement("");
6397 			statement("switch (ycbcr_samp.get_ycbcr_model())");
6398 			begin_scope();
6399 			statement("case spvYCbCrModelConversion::rgb_identity:"); // Silence Clang warning
6400 			statement("case spvYCbCrModelConversion::ycbcr_identity:");
6401 			statement("    return s;");
6402 			statement("case spvYCbCrModelConversion::ycbcr_bt_709:");
6403 			statement("    return spvConvertYCbCrBT709(s);");
6404 			statement("case spvYCbCrModelConversion::ycbcr_bt_601:");
6405 			statement("    return spvConvertYCbCrBT601(s);");
6406 			statement("case spvYCbCrModelConversion::ycbcr_bt_2020:");
6407 			statement("    return spvConvertYCbCrBT2020(s);");
6408 			end_scope();
6409 			end_scope();
6410 			statement("");
6411 			// Sampler Y'CbCr conversion forbids offsets.
6412 			statement("vec<T, 4> sample(float2 coord, int2 offset) const thread");
6413 			begin_scope();
6414 			if (msl_options.swizzle_texture_samples)
6415 				statement("return spvTextureSwizzle(plane0.sample(samp, coord, offset), swizzle);");
6416 			else
6417 				statement("return plane0.sample(samp, coord, offset);");
6418 			end_scope();
6419 			statement("template<typename lod_options>");
6420 			statement("vec<T, 4> sample(float2 coord, lod_options options, int2 offset) const thread");
6421 			begin_scope();
6422 			if (msl_options.swizzle_texture_samples)
6423 				statement("return spvTextureSwizzle(plane0.sample(samp, coord, options, offset), swizzle);");
6424 			else
6425 				statement("return plane0.sample(samp, coord, options, offset);");
6426 			end_scope();
6427 			statement("#if __HAVE_MIN_LOD_CLAMP__");
6428 			statement("vec<T, 4> sample(float2 coord, bias b, min_lod_clamp min_lod, int2 offset) const thread");
6429 			begin_scope();
6430 			statement("return plane0.sample(samp, coord, b, min_lod, offset);");
6431 			end_scope();
6432 			statement(
6433 			    "vec<T, 4> sample(float2 coord, gradient2d grad, min_lod_clamp min_lod, int2 offset) const thread");
6434 			begin_scope();
6435 			statement("return plane0.sample(samp, coord, grad, min_lod, offset);");
6436 			end_scope();
6437 			statement("#endif");
6438 			statement("");
6439 			// Y'CbCr conversion forbids all operations but sampling.
6440 			statement("vec<T, 4> read(uint2 coord, uint lod = 0) const thread");
6441 			begin_scope();
6442 			statement("return plane0.read(coord, lod);");
6443 			end_scope();
6444 			statement("");
6445 			statement("vec<T, 4> gather(float2 coord, int2 offset = int2(0), component c = component::x) const thread");
6446 			begin_scope();
6447 			if (msl_options.swizzle_texture_samples)
6448 				statement("return spvGatherSwizzle(plane0, samp, swizzle, c, coord, offset);");
6449 			else
6450 				statement("return plane0.gather(samp, coord, offset, c);");
6451 			end_scope();
6452 			end_scope_decl();
6453 			statement("");
6454 
6455 		default:
6456 			break;
6457 		}
6458 	}
6459 }
6460 
// Injects a storage qualifier (e.g. "constant") at the top level of a textual declaration.
//
// expr:      the declaration text to qualify.
// qualifier: the qualifier text to insert.
//
// If expr contains a '&' or '*', the result is everything up to and including the last such
// character, a space, the qualifier, and then the untouched remainder of expr.
// Otherwise the result is the qualifier, a space, and expr.
static std::string inject_top_level_storage_qualifier(const std::string &expr, const std::string &qualifier)
{
	// Easier to do this through text munging since the qualifier does not exist in the type system at all,
	// and plumbing in all that information is not very helpful.
	// A single search over the character set yields the right-most '&' or '*', whichever comes last.
	const size_t last_significant = expr.find_last_of("&*");
	if (last_significant == std::string::npos)
		return qualifier + " " + expr;

	// The remainder is appended verbatim; no space is added after the qualifier.
	const size_t split = last_significant + 1;
	return expr.substr(0, split) + " " + qualifier + expr.substr(split);
}
6484 
6485 // Undefined global memory is not allowed in MSL.
6486 // Declare constant and init to zeros. Use {}, as global constructors can break Metal.
declare_undefined_values()6487 void CompilerMSL::declare_undefined_values()
6488 {
6489 	bool emitted = false;
6490 	ir.for_each_typed_id<SPIRUndef>([&](uint32_t, SPIRUndef &undef) {
6491 		auto &type = this->get<SPIRType>(undef.basetype);
6492 		// OpUndef can be void for some reason ...
6493 		if (type.basetype == SPIRType::Void)
6494 			return;
6495 
6496 		statement(inject_top_level_storage_qualifier(
6497 				variable_decl(type, to_name(undef.self), undef.self),
6498 				"constant"),
6499 		          " = {};");
6500 		emitted = true;
6501 	});
6502 
6503 	if (emitted)
6504 		statement("");
6505 }
6506 
declare_constant_arrays()6507 void CompilerMSL::declare_constant_arrays()
6508 {
6509 	bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
6510 
6511 	// MSL cannot declare arrays inline (except when declaring a variable), so we must move them out to
6512 	// global constants directly, so we are able to use constants as variable expressions.
6513 	bool emitted = false;
6514 
6515 	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
6516 		if (c.specialization)
6517 			return;
6518 
6519 		auto &type = this->get<SPIRType>(c.constant_type);
6520 		// Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries.
6521 		// FIXME: However, hoisting constants to main() means we need to pass down constant arrays to leaf functions if they are used there.
6522 		// If there are multiple functions in the module, drop this case to avoid breaking use cases which do not need to
6523 		// link into Metal libraries. This is hacky.
6524 		if (!type.array.empty() && (!fully_inlined || is_scalar(type) || is_vector(type)))
6525 		{
6526 			auto name = to_name(c.self);
6527 			statement(inject_top_level_storage_qualifier(variable_decl(type, name), "constant"),
6528 			          " = ", constant_expression(c), ";");
6529 			emitted = true;
6530 		}
6531 	});
6532 
6533 	if (emitted)
6534 		statement("");
6535 }
6536 
6537 // Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries
declare_complex_constant_arrays()6538 void CompilerMSL::declare_complex_constant_arrays()
6539 {
6540 	// If we do not have a fully inlined module, we did not opt in to
6541 	// declaring constant arrays of complex types. See CompilerMSL::declare_constant_arrays().
6542 	bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
6543 	if (!fully_inlined)
6544 		return;
6545 
6546 	// MSL cannot declare arrays inline (except when declaring a variable), so we must move them out to
6547 	// global constants directly, so we are able to use constants as variable expressions.
6548 	bool emitted = false;
6549 
6550 	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
6551 		if (c.specialization)
6552 			return;
6553 
6554 		auto &type = this->get<SPIRType>(c.constant_type);
6555 		if (!type.array.empty() && !(is_scalar(type) || is_vector(type)))
6556 		{
6557 			auto name = to_name(c.self);
6558 			statement("", variable_decl(type, name), " = ", constant_expression(c), ";");
6559 			emitted = true;
6560 		}
6561 	});
6562 
6563 	if (emitted)
6564 		statement("");
6565 }
6566 
emit_resources()6567 void CompilerMSL::emit_resources()
6568 {
6569 	declare_constant_arrays();
6570 	declare_undefined_values();
6571 
6572 	// Emit the special [[stage_in]] and [[stage_out]] interface blocks which we created.
6573 	emit_interface_block(stage_out_var_id);
6574 	emit_interface_block(patch_stage_out_var_id);
6575 	emit_interface_block(stage_in_var_id);
6576 	emit_interface_block(patch_stage_in_var_id);
6577 }
6578 
6579 // Emit declarations for the specialization Metal function constants
emit_specialization_constants_and_structs()6580 void CompilerMSL::emit_specialization_constants_and_structs()
6581 {
6582 	SpecializationConstant wg_x, wg_y, wg_z;
6583 	ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
6584 	bool emitted = false;
6585 
6586 	unordered_set<uint32_t> declared_structs;
6587 	unordered_set<uint32_t> aligned_structs;
6588 
6589 	// First, we need to deal with scalar block layout.
6590 	// It is possible that a struct may have to be placed at an alignment which does not match the innate alignment of the struct itself.
6591 	// In that case, if such a case exists for a struct, we must force that all elements of the struct become packed_ types.
6592 	// This makes the struct alignment as small as physically possible.
6593 	// When we actually align the struct later, we can insert padding as necessary to make the packed members behave like normally aligned types.
6594 	ir.for_each_typed_id<SPIRType>([&](uint32_t type_id, const SPIRType &type) {
6595 		if (type.basetype == SPIRType::Struct &&
6596 		    has_extended_decoration(type_id, SPIRVCrossDecorationBufferBlockRepacked))
6597 			mark_scalar_layout_structs(type);
6598 	});
6599 
6600 	bool builtin_block_type_is_required = false;
6601 	// Very special case. If gl_PerVertex is initialized as an array (tessellation)
6602 	// we have to potentially emit the gl_PerVertex struct type so that we can emit a constant LUT.
6603 	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
6604 		auto &type = this->get<SPIRType>(c.constant_type);
6605 		if (is_array(type) && has_decoration(type.self, DecorationBlock) && is_builtin_type(type))
6606 			builtin_block_type_is_required = true;
6607 	});
6608 
6609 	// Very particular use of the soft loop lock.
6610 	// align_struct may need to create custom types on the fly, but we don't care about
6611 	// these types for purpose of iterating over them in ir.ids_for_type and friends.
6612 	auto loop_lock = ir.create_loop_soft_lock();
6613 
6614 	for (auto &id_ : ir.ids_for_constant_or_type)
6615 	{
6616 		auto &id = ir.ids[id_];
6617 
6618 		if (id.get_type() == TypeConstant)
6619 		{
6620 			auto &c = id.get<SPIRConstant>();
6621 
6622 			if (c.self == workgroup_size_id)
6623 			{
6624 				// TODO: This can be expressed as a [[threads_per_threadgroup]] input semantic, but we need to know
6625 				// the work group size at compile time in SPIR-V, and [[threads_per_threadgroup]] would need to be passed around as a global.
6626 				// The work group size may be a specialization constant.
6627 				statement("constant uint3 ", builtin_to_glsl(BuiltInWorkgroupSize, StorageClassWorkgroup),
6628 				          " [[maybe_unused]] = ", constant_expression(get<SPIRConstant>(workgroup_size_id)), ";");
6629 				emitted = true;
6630 			}
6631 			else if (c.specialization)
6632 			{
6633 				auto &type = get<SPIRType>(c.constant_type);
6634 				string sc_type_name = type_to_glsl(type);
6635 				string sc_name = to_name(c.self);
6636 				string sc_tmp_name = sc_name + "_tmp";
6637 
6638 				// Function constants are only supported in MSL 1.2 and later.
6639 				// If we don't support it just declare the "default" directly.
6640 				// This "default" value can be overridden to the true specialization constant by the API user.
6641 				// Specialization constants which are used as array length expressions cannot be function constants in MSL,
6642 				// so just fall back to macros.
6643 				if (msl_options.supports_msl_version(1, 2) && has_decoration(c.self, DecorationSpecId) &&
6644 				    !c.is_used_as_array_length)
6645 				{
6646 					uint32_t constant_id = get_decoration(c.self, DecorationSpecId);
6647 					// Only scalar, non-composite values can be function constants.
6648 					statement("constant ", sc_type_name, " ", sc_tmp_name, " [[function_constant(", constant_id,
6649 					          ")]];");
6650 					statement("constant ", sc_type_name, " ", sc_name, " = is_function_constant_defined(", sc_tmp_name,
6651 					          ") ? ", sc_tmp_name, " : ", constant_expression(c), ";");
6652 				}
6653 				else if (has_decoration(c.self, DecorationSpecId))
6654 				{
6655 					// Fallback to macro overrides.
6656 					c.specialization_constant_macro_name =
6657 					    constant_value_macro_name(get_decoration(c.self, DecorationSpecId));
6658 
6659 					statement("#ifndef ", c.specialization_constant_macro_name);
6660 					statement("#define ", c.specialization_constant_macro_name, " ", constant_expression(c));
6661 					statement("#endif");
6662 					statement("constant ", sc_type_name, " ", sc_name, " = ", c.specialization_constant_macro_name,
6663 					          ";");
6664 				}
6665 				else
6666 				{
6667 					// Composite specialization constants must be built from other specialization constants.
6668 					statement("constant ", sc_type_name, " ", sc_name, " = ", constant_expression(c), ";");
6669 				}
6670 				emitted = true;
6671 			}
6672 		}
6673 		else if (id.get_type() == TypeConstantOp)
6674 		{
6675 			auto &c = id.get<SPIRConstantOp>();
6676 			auto &type = get<SPIRType>(c.basetype);
6677 			auto name = to_name(c.self);
6678 			statement("constant ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
6679 			emitted = true;
6680 		}
6681 		else if (id.get_type() == TypeType)
6682 		{
6683 			// Output non-builtin interface structs. These include local function structs
6684 			// and structs nested within uniform and read-write buffers.
6685 			auto &type = id.get<SPIRType>();
6686 			TypeID type_id = type.self;
6687 
6688 			bool is_struct = (type.basetype == SPIRType::Struct) && type.array.empty() && !type.pointer;
6689 			bool is_block =
6690 			    has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);
6691 
6692 			bool is_builtin_block = is_block && is_builtin_type(type);
6693 			bool is_declarable_struct = is_struct && (!is_builtin_block || builtin_block_type_is_required);
6694 
6695 			// We'll declare this later.
6696 			if (stage_out_var_id && get_stage_out_struct_type().self == type_id)
6697 				is_declarable_struct = false;
6698 			if (patch_stage_out_var_id && get_patch_stage_out_struct_type().self == type_id)
6699 				is_declarable_struct = false;
6700 			if (stage_in_var_id && get_stage_in_struct_type().self == type_id)
6701 				is_declarable_struct = false;
6702 			if (patch_stage_in_var_id && get_patch_stage_in_struct_type().self == type_id)
6703 				is_declarable_struct = false;
6704 
6705 			// Special case. Declare builtin struct anyways if we need to emit a threadgroup version of it.
6706 			if (stage_out_masked_builtin_type_id == type_id)
6707 				is_declarable_struct = true;
6708 
6709 			// Align and emit declarable structs...but avoid declaring each more than once.
6710 			if (is_declarable_struct && declared_structs.count(type_id) == 0)
6711 			{
6712 				if (emitted)
6713 					statement("");
6714 				emitted = false;
6715 
6716 				declared_structs.insert(type_id);
6717 
6718 				if (has_extended_decoration(type_id, SPIRVCrossDecorationBufferBlockRepacked))
6719 					align_struct(type, aligned_structs);
6720 
6721 				// Make sure we declare the underlying struct type, and not the "decorated" type with pointers, etc.
6722 				emit_struct(get<SPIRType>(type_id));
6723 			}
6724 		}
6725 	}
6726 
6727 	if (emitted)
6728 		statement("");
6729 }
6730 
emit_binary_unord_op(uint32_t result_type,uint32_t result_id,uint32_t op0,uint32_t op1,const char * op)6731 void CompilerMSL::emit_binary_unord_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
6732                                        const char *op)
6733 {
6734 	bool forward = should_forward(op0) && should_forward(op1);
6735 	emit_op(result_type, result_id,
6736 	        join("(isunordered(", to_enclosed_unpacked_expression(op0), ", ", to_enclosed_unpacked_expression(op1),
6737 	             ") || ", to_enclosed_unpacked_expression(op0), " ", op, " ", to_enclosed_unpacked_expression(op1),
6738 	             ")"),
6739 	        forward);
6740 
6741 	inherit_expression_dependencies(result_id, op0);
6742 	inherit_expression_dependencies(result_id, op1);
6743 }
6744 
emit_tessellation_io_load(uint32_t result_type_id,uint32_t id,uint32_t ptr)6745 bool CompilerMSL::emit_tessellation_io_load(uint32_t result_type_id, uint32_t id, uint32_t ptr)
6746 {
6747 	auto &ptr_type = expression_type(ptr);
6748 	auto &result_type = get<SPIRType>(result_type_id);
6749 	if (ptr_type.storage != StorageClassInput && ptr_type.storage != StorageClassOutput)
6750 		return false;
6751 	if (ptr_type.storage == StorageClassOutput && get_execution_model() == ExecutionModelTessellationEvaluation)
6752 		return false;
6753 
6754 	if (has_decoration(ptr, DecorationPatch))
6755 		return false;
6756 	bool ptr_is_io_variable = ir.ids[ptr].get_type() == TypeVariable;
6757 
6758 	bool flattened_io = variable_storage_requires_stage_io(ptr_type.storage);
6759 
6760 	bool flat_data_type = flattened_io &&
6761 	                      (is_matrix(result_type) || is_array(result_type) || result_type.basetype == SPIRType::Struct);
6762 
6763 	// Edge case, even with multi-patch workgroups, we still need to unroll load
6764 	// if we're loading control points directly.
6765 	if (ptr_is_io_variable && is_array(result_type))
6766 		flat_data_type = true;
6767 
6768 	if (!flat_data_type)
6769 		return false;
6770 
6771 	// Now, we must unflatten a composite type and take care of interleaving array access with gl_in/gl_out.
6772 	// Lots of painful code duplication since we *really* should not unroll these kinds of loads in entry point fixup
6773 	// unless we're forced to do this when the code is emitting inoptimal OpLoads.
6774 	string expr;
6775 
6776 	uint32_t interface_index = get_extended_decoration(ptr, SPIRVCrossDecorationInterfaceMemberIndex);
6777 	auto *var = maybe_get_backing_variable(ptr);
6778 	auto &expr_type = get_pointee_type(ptr_type.self);
6779 
6780 	const auto &iface_type = expression_type(stage_in_ptr_var_id);
6781 
6782 	if (!flattened_io)
6783 	{
6784 		// Simplest case for multi-patch workgroups, just unroll array as-is.
6785 		if (interface_index == uint32_t(-1))
6786 			return false;
6787 
6788 		expr += type_to_glsl(result_type) + "({ ";
6789 		uint32_t num_control_points = to_array_size_literal(result_type, uint32_t(result_type.array.size()) - 1);
6790 
6791 		for (uint32_t i = 0; i < num_control_points; i++)
6792 		{
6793 			const uint32_t indices[2] = { i, interface_index };
6794 			AccessChainMeta meta;
6795 			expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
6796 			                              ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6797 			if (i + 1 < num_control_points)
6798 				expr += ", ";
6799 		}
6800 		expr += " })";
6801 	}
6802 	else if (result_type.array.size() > 2)
6803 	{
6804 		SPIRV_CROSS_THROW("Cannot load tessellation IO variables with more than 2 dimensions.");
6805 	}
6806 	else if (result_type.array.size() == 2)
6807 	{
6808 		if (!ptr_is_io_variable)
6809 			SPIRV_CROSS_THROW("Loading an array-of-array must be loaded directly from an IO variable.");
6810 		if (interface_index == uint32_t(-1))
6811 			SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6812 		if (result_type.basetype == SPIRType::Struct || is_matrix(result_type))
6813 			SPIRV_CROSS_THROW("Cannot load array-of-array of composite type in tessellation IO.");
6814 
6815 		expr += type_to_glsl(result_type) + "({ ";
6816 		uint32_t num_control_points = to_array_size_literal(result_type, 1);
6817 		uint32_t base_interface_index = interface_index;
6818 
6819 		auto &sub_type = get<SPIRType>(result_type.parent_type);
6820 
6821 		for (uint32_t i = 0; i < num_control_points; i++)
6822 		{
6823 			expr += type_to_glsl(sub_type) + "({ ";
6824 			interface_index = base_interface_index;
6825 			uint32_t array_size = to_array_size_literal(result_type, 0);
6826 			for (uint32_t j = 0; j < array_size; j++, interface_index++)
6827 			{
6828 				const uint32_t indices[2] = { i, interface_index };
6829 
6830 				AccessChainMeta meta;
6831 				expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
6832 				                              ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6833 				if (!is_matrix(sub_type) && sub_type.basetype != SPIRType::Struct &&
6834 					expr_type.vecsize > sub_type.vecsize)
6835 					expr += vector_swizzle(sub_type.vecsize, 0);
6836 
6837 				if (j + 1 < array_size)
6838 					expr += ", ";
6839 			}
6840 			expr += " })";
6841 			if (i + 1 < num_control_points)
6842 				expr += ", ";
6843 		}
6844 		expr += " })";
6845 	}
6846 	else if (result_type.basetype == SPIRType::Struct)
6847 	{
6848 		bool is_array_of_struct = is_array(result_type);
6849 		if (is_array_of_struct && !ptr_is_io_variable)
6850 			SPIRV_CROSS_THROW("Loading array of struct from IO variable must come directly from IO variable.");
6851 
6852 		uint32_t num_control_points = 1;
6853 		if (is_array_of_struct)
6854 		{
6855 			num_control_points = to_array_size_literal(result_type, 0);
6856 			expr += type_to_glsl(result_type) + "({ ";
6857 		}
6858 
6859 		auto &struct_type = is_array_of_struct ? get<SPIRType>(result_type.parent_type) : result_type;
6860 		assert(struct_type.array.empty());
6861 
6862 		for (uint32_t i = 0; i < num_control_points; i++)
6863 		{
6864 			expr += type_to_glsl(struct_type) + "{ ";
6865 			for (uint32_t j = 0; j < uint32_t(struct_type.member_types.size()); j++)
6866 			{
6867 				// The base interface index is stored per variable for structs.
6868 				if (var)
6869 				{
6870 					interface_index =
6871 					    get_extended_member_decoration(var->self, j, SPIRVCrossDecorationInterfaceMemberIndex);
6872 				}
6873 
6874 				if (interface_index == uint32_t(-1))
6875 					SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6876 
6877 				const auto &mbr_type = get<SPIRType>(struct_type.member_types[j]);
6878 				const auto &expr_mbr_type = get<SPIRType>(expr_type.member_types[j]);
6879 				if (is_matrix(mbr_type) && ptr_type.storage == StorageClassInput)
6880 				{
6881 					expr += type_to_glsl(mbr_type) + "(";
6882 					for (uint32_t k = 0; k < mbr_type.columns; k++, interface_index++)
6883 					{
6884 						if (is_array_of_struct)
6885 						{
6886 							const uint32_t indices[2] = { i, interface_index };
6887 							AccessChainMeta meta;
6888 							expr += access_chain_internal(
6889 									stage_in_ptr_var_id, indices, 2,
6890 									ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6891 						}
6892 						else
6893 							expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6894 						if (expr_mbr_type.vecsize > mbr_type.vecsize)
6895 							expr += vector_swizzle(mbr_type.vecsize, 0);
6896 
6897 						if (k + 1 < mbr_type.columns)
6898 							expr += ", ";
6899 					}
6900 					expr += ")";
6901 				}
6902 				else if (is_array(mbr_type))
6903 				{
6904 					expr += type_to_glsl(mbr_type) + "({ ";
6905 					uint32_t array_size = to_array_size_literal(mbr_type, 0);
6906 					for (uint32_t k = 0; k < array_size; k++, interface_index++)
6907 					{
6908 						if (is_array_of_struct)
6909 						{
6910 							const uint32_t indices[2] = { i, interface_index };
6911 							AccessChainMeta meta;
6912 							expr += access_chain_internal(
6913 									stage_in_ptr_var_id, indices, 2,
6914 									ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6915 						}
6916 						else
6917 							expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6918 						if (expr_mbr_type.vecsize > mbr_type.vecsize)
6919 							expr += vector_swizzle(mbr_type.vecsize, 0);
6920 
6921 						if (k + 1 < array_size)
6922 							expr += ", ";
6923 					}
6924 					expr += " })";
6925 				}
6926 				else
6927 				{
6928 					if (is_array_of_struct)
6929 					{
6930 						const uint32_t indices[2] = { i, interface_index };
6931 						AccessChainMeta meta;
6932 						expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
6933 						                              ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT,
6934 						                              &meta);
6935 					}
6936 					else
6937 						expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6938 					if (expr_mbr_type.vecsize > mbr_type.vecsize)
6939 						expr += vector_swizzle(mbr_type.vecsize, 0);
6940 				}
6941 
6942 				if (j + 1 < struct_type.member_types.size())
6943 					expr += ", ";
6944 			}
6945 			expr += " }";
6946 			if (i + 1 < num_control_points)
6947 				expr += ", ";
6948 		}
6949 		if (is_array_of_struct)
6950 			expr += " })";
6951 	}
6952 	else if (is_matrix(result_type))
6953 	{
6954 		bool is_array_of_matrix = is_array(result_type);
6955 		if (is_array_of_matrix && !ptr_is_io_variable)
6956 			SPIRV_CROSS_THROW("Loading array of matrix from IO variable must come directly from IO variable.");
6957 		if (interface_index == uint32_t(-1))
6958 			SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6959 
6960 		if (is_array_of_matrix)
6961 		{
6962 			// Loading a matrix from each control point.
6963 			uint32_t base_interface_index = interface_index;
6964 			uint32_t num_control_points = to_array_size_literal(result_type, 0);
6965 			expr += type_to_glsl(result_type) + "({ ";
6966 
6967 			auto &matrix_type = get_variable_element_type(get<SPIRVariable>(ptr));
6968 
6969 			for (uint32_t i = 0; i < num_control_points; i++)
6970 			{
6971 				interface_index = base_interface_index;
6972 				expr += type_to_glsl(matrix_type) + "(";
6973 				for (uint32_t j = 0; j < result_type.columns; j++, interface_index++)
6974 				{
6975 					const uint32_t indices[2] = { i, interface_index };
6976 
6977 					AccessChainMeta meta;
6978 					expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
6979 					                              ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6980 					if (expr_type.vecsize > result_type.vecsize)
6981 						expr += vector_swizzle(result_type.vecsize, 0);
6982 					if (j + 1 < result_type.columns)
6983 						expr += ", ";
6984 				}
6985 				expr += ")";
6986 				if (i + 1 < num_control_points)
6987 					expr += ", ";
6988 			}
6989 
6990 			expr += " })";
6991 		}
6992 		else
6993 		{
6994 			expr += type_to_glsl(result_type) + "(";
6995 			for (uint32_t i = 0; i < result_type.columns; i++, interface_index++)
6996 			{
6997 				expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6998 				if (expr_type.vecsize > result_type.vecsize)
6999 					expr += vector_swizzle(result_type.vecsize, 0);
7000 				if (i + 1 < result_type.columns)
7001 					expr += ", ";
7002 			}
7003 			expr += ")";
7004 		}
7005 	}
7006 	else if (ptr_is_io_variable)
7007 	{
7008 		assert(is_array(result_type));
7009 		assert(result_type.array.size() == 1);
7010 		if (interface_index == uint32_t(-1))
7011 			SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
7012 
7013 		// We're loading an array directly from a global variable.
7014 		// This means we're loading one member from each control point.
7015 		expr += type_to_glsl(result_type) + "({ ";
7016 		uint32_t num_control_points = to_array_size_literal(result_type, 0);
7017 
7018 		for (uint32_t i = 0; i < num_control_points; i++)
7019 		{
7020 			const uint32_t indices[2] = { i, interface_index };
7021 
7022 			AccessChainMeta meta;
7023 			expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
7024 			                              ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
7025 			if (expr_type.vecsize > result_type.vecsize)
7026 				expr += vector_swizzle(result_type.vecsize, 0);
7027 
7028 			if (i + 1 < num_control_points)
7029 				expr += ", ";
7030 		}
7031 		expr += " })";
7032 	}
7033 	else
7034 	{
7035 		// We're loading an array from a concrete control point.
7036 		assert(is_array(result_type));
7037 		assert(result_type.array.size() == 1);
7038 		if (interface_index == uint32_t(-1))
7039 			SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
7040 
7041 		expr += type_to_glsl(result_type) + "({ ";
7042 		uint32_t array_size = to_array_size_literal(result_type, 0);
7043 		for (uint32_t i = 0; i < array_size; i++, interface_index++)
7044 		{
7045 			expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
7046 			if (expr_type.vecsize > result_type.vecsize)
7047 				expr += vector_swizzle(result_type.vecsize, 0);
7048 			if (i + 1 < array_size)
7049 				expr += ", ";
7050 		}
7051 		expr += " })";
7052 	}
7053 
7054 	emit_op(result_type_id, id, expr, false);
7055 	register_read(id, ptr, false);
7056 	return true;
7057 }
7058 
emit_tessellation_access_chain(const uint32_t * ops,uint32_t length)7059 bool CompilerMSL::emit_tessellation_access_chain(const uint32_t *ops, uint32_t length)
7060 {
7061 	// If this is a per-vertex output, remap it to the I/O array buffer.
7062 
7063 	// Any object which did not go through IO flattening shenanigans will go there instead.
7064 	// We will unflatten on-demand instead as needed, but not all possible cases can be supported, especially with arrays.
7065 
7066 	auto *var = maybe_get_backing_variable(ops[2]);
7067 	bool patch = false;
7068 	bool flat_data = false;
7069 	bool ptr_is_chain = false;
7070 	bool flatten_composites = false;
7071 
7072 	bool is_block = false;
7073 
7074 	if (var)
7075 		is_block = has_decoration(get_variable_data_type(*var).self, DecorationBlock);
7076 
7077 	if (var)
7078 	{
7079 		flatten_composites = variable_storage_requires_stage_io(var->storage);
7080 		patch = has_decoration(ops[2], DecorationPatch) || is_patch_block(get_variable_data_type(*var));
7081 
7082 		// Should match strip_array in add_interface_block.
7083 		flat_data = var->storage == StorageClassInput ||
7084 		            (var->storage == StorageClassOutput && get_execution_model() == ExecutionModelTessellationControl);
7085 
7086 		// Patch inputs are treated as normal block IO variables, so they don't deal with this path at all.
7087 		if (patch && (!is_block || var->storage == StorageClassInput))
7088 			flat_data = false;
7089 
7090 		// We might have a chained access chain, where
7091 		// we first take the access chain to the control point, and then we chain into a member or something similar.
7092 		// In this case, we need to skip gl_in/gl_out remapping.
7093 		// Also, skip ptr chain for patches.
7094 		ptr_is_chain = var->self != ID(ops[2]);
7095 	}
7096 
7097 	bool builtin_variable = false;
7098 	bool variable_is_flat = false;
7099 
7100 	if (var && flat_data)
7101 	{
7102 		builtin_variable = is_builtin_variable(*var);
7103 
7104 		BuiltIn bi_type = BuiltInMax;
7105 		if (builtin_variable && !is_block)
7106 			bi_type = BuiltIn(get_decoration(var->self, DecorationBuiltIn));
7107 
7108 		variable_is_flat = !builtin_variable || is_block ||
7109 		                   bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
7110 		                   bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance;
7111 	}
7112 
7113 	if (variable_is_flat)
7114 	{
7115 		// If output is masked, it is emitted as a "normal" variable, just go through normal code paths.
7116 		// Only check this for the first level of access chain.
7117 		// Dealing with this for partial access chains should be possible, but awkward.
7118 		if (var->storage == StorageClassOutput && !ptr_is_chain)
7119 		{
7120 			bool masked = false;
7121 			if (is_block)
7122 			{
7123 				uint32_t relevant_member_index = patch ? 3 : 4;
7124 				// FIXME: This won't work properly if the application first access chains into gl_out element,
7125 				// then access chains into the member. Super weird, but theoretically possible ...
7126 				if (length > relevant_member_index)
7127 				{
7128 					uint32_t mbr_idx = get<SPIRConstant>(ops[relevant_member_index]).scalar();
7129 					masked = is_stage_output_block_member_masked(*var, mbr_idx, true);
7130 				}
7131 			}
7132 			else if (var)
7133 				masked = is_stage_output_variable_masked(*var);
7134 
7135 			if (masked)
7136 				return false;
7137 		}
7138 
7139 		AccessChainMeta meta;
7140 		SmallVector<uint32_t> indices;
7141 		uint32_t next_id = ir.increase_bound_by(1);
7142 
7143 		indices.reserve(length - 3 + 1);
7144 
7145 		uint32_t first_non_array_index = (ptr_is_chain ? 3 : 4) - (patch ? 1 : 0);
7146 
7147 		VariableID stage_var_id;
7148 		if (patch)
7149 			stage_var_id = var->storage == StorageClassInput ? patch_stage_in_var_id : patch_stage_out_var_id;
7150 		else
7151 			stage_var_id = var->storage == StorageClassInput ? stage_in_ptr_var_id : stage_out_ptr_var_id;
7152 
7153 		VariableID ptr = ptr_is_chain ? VariableID(ops[2]) : stage_var_id;
7154 		if (!ptr_is_chain && !patch)
7155 		{
7156 			// Index into gl_in/gl_out with first array index.
7157 			indices.push_back(ops[first_non_array_index - 1]);
7158 		}
7159 
7160 		auto &result_ptr_type = get<SPIRType>(ops[0]);
7161 
7162 		uint32_t const_mbr_id = next_id++;
7163 		uint32_t index = get_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex);
7164 
7165 		// If we have a pointer chain expression, and we are no longer pointing to a composite
7166 		// object, we are in the clear. There is no longer a need to flatten anything.
7167 		bool further_access_chain_is_trivial = false;
7168 		if (ptr_is_chain && flatten_composites)
7169 		{
7170 			auto &ptr_type = expression_type(ptr);
7171 			if (!is_array(ptr_type) && !is_matrix(ptr_type) && ptr_type.basetype != SPIRType::Struct)
7172 				further_access_chain_is_trivial = true;
7173 		}
7174 
7175 		if (!further_access_chain_is_trivial && (flatten_composites || is_block))
7176 		{
7177 			uint32_t i = first_non_array_index;
7178 			auto *type = &get_variable_element_type(*var);
7179 			if (index == uint32_t(-1) && length >= (first_non_array_index + 1))
7180 			{
7181 				// Maybe this is a struct type in the input class, in which case
7182 				// we put it as a decoration on the corresponding member.
7183 				uint32_t mbr_idx = get_constant(ops[first_non_array_index]).scalar();
7184 				index = get_extended_member_decoration(var->self, mbr_idx,
7185 				                                       SPIRVCrossDecorationInterfaceMemberIndex);
7186 				assert(index != uint32_t(-1));
7187 				i++;
7188 				type = &get<SPIRType>(type->member_types[mbr_idx]);
7189 			}
7190 
7191 			// In this case, we're poking into flattened structures and arrays, so now we have to
7192 			// combine the following indices. If we encounter a non-constant index,
7193 			// we're hosed.
7194 			for (; flatten_composites && i < length; ++i)
7195 			{
7196 				if (!is_array(*type) && !is_matrix(*type) && type->basetype != SPIRType::Struct)
7197 					break;
7198 
7199 				auto *c = maybe_get<SPIRConstant>(ops[i]);
7200 				if (!c || c->specialization)
7201 					SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable in tessellation. "
7202 					                  "This is currently unsupported.");
7203 
7204 				// We're in flattened space, so just increment the member index into IO block.
7205 				// We can only do this once in the current implementation, so either:
7206 				// Struct, Matrix or 1-dimensional array for a control point.
7207 				if (type->basetype == SPIRType::Struct && var->storage == StorageClassOutput)
7208 				{
7209 					// Need to consider holes, since individual block members might be masked away.
7210 					uint32_t mbr_idx = c->scalar();
7211 					for (uint32_t j = 0; j < mbr_idx; j++)
7212 						if (!is_stage_output_block_member_masked(*var, j, true))
7213 							index++;
7214 				}
7215 				else
7216 					index += c->scalar();
7217 
7218 				if (type->parent_type)
7219 					type = &get<SPIRType>(type->parent_type);
7220 				else if (type->basetype == SPIRType::Struct)
7221 					type = &get<SPIRType>(type->member_types[c->scalar()]);
7222 			}
7223 
7224 			// We're not going to emit the actual member name, we let any further OpLoad take care of that.
7225 			// Tag the access chain with the member index we're referencing.
7226 			bool defer_access_chain = flatten_composites && (is_matrix(result_ptr_type) || is_array(result_ptr_type) ||
7227 			                                                 result_ptr_type.basetype == SPIRType::Struct);
7228 
7229 			if (!defer_access_chain)
7230 			{
7231 				// Access the appropriate member of gl_in/gl_out.
7232 				set<SPIRConstant>(const_mbr_id, get_uint_type_id(), index, false);
7233 				indices.push_back(const_mbr_id);
7234 
7235 				// Member index is now irrelevant.
7236 				index = uint32_t(-1);
7237 
7238 				// Append any straggling access chain indices.
7239 				if (i < length)
7240 					indices.insert(indices.end(), ops + i, ops + length);
7241 			}
7242 			else
7243 			{
7244 				// We must have consumed the entire access chain if we're deferring it.
7245 				assert(i == length);
7246 			}
7247 
7248 			if (index != uint32_t(-1))
7249 				set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, index);
7250 			else
7251 				unset_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex);
7252 		}
7253 		else
7254 		{
7255 			if (index != uint32_t(-1))
7256 			{
7257 				set<SPIRConstant>(const_mbr_id, get_uint_type_id(), index, false);
7258 				indices.push_back(const_mbr_id);
7259 			}
7260 
7261 			// Member index is now irrelevant.
7262 			index = uint32_t(-1);
7263 			unset_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex);
7264 
7265 			indices.insert(indices.end(), ops + first_non_array_index, ops + length);
7266 		}
7267 
7268 		// We use the pointer to the base of the input/output array here,
7269 		// so this is always a pointer chain.
7270 		string e;
7271 
7272 		if (!ptr_is_chain)
7273 		{
7274 			// This is the start of an access chain, use ptr_chain to index into control point array.
7275 			e = access_chain(ptr, indices.data(), uint32_t(indices.size()), result_ptr_type, &meta, !patch);
7276 		}
7277 		else
7278 		{
7279 			// If we're accessing a struct, we need to use member indices which are based on the IO block,
7280 			// not actual struct type, so we have to use a split access chain here where
7281 			// first path resolves the control point index, i.e. gl_in[index], and second half deals with
7282 			// looking up flattened member name.
7283 
7284 			// However, it is possible that we partially accessed a struct,
7285 			// by taking pointer to member inside the control-point array.
7286 			// For this case, we fall back to a natural access chain since we have already dealt with remapping struct members.
7287 			// One way to check this here is if we have 2 implied read expressions.
7288 			// First one is the gl_in/gl_out struct itself, then an index into that array.
7289 			// If we have traversed further, we use a normal access chain formulation.
7290 			auto *ptr_expr = maybe_get<SPIRExpression>(ptr);
7291 			bool split_access_chain_formulation = flatten_composites && ptr_expr &&
7292 			                                      ptr_expr->implied_read_expressions.size() == 2 &&
7293 			                                      !further_access_chain_is_trivial;
7294 
7295 			if (split_access_chain_formulation)
7296 			{
7297 				e = join(to_expression(ptr),
7298 				         access_chain_internal(stage_var_id, indices.data(), uint32_t(indices.size()),
7299 				                               ACCESS_CHAIN_CHAIN_ONLY_BIT, &meta));
7300 			}
7301 			else
7302 			{
7303 				e = access_chain_internal(ptr, indices.data(), uint32_t(indices.size()), 0, &meta);
7304 			}
7305 		}
7306 
7307 		// Get the actual type of the object that was accessed. If it's a vector type and we changed it,
7308 		// then we'll need to add a swizzle.
7309 		// For this, we can't necessarily rely on the type of the base expression, because it might be
7310 		// another access chain, and it will therefore already have the "correct" type.
7311 		auto *expr_type = &get_variable_data_type(*var);
7312 		if (has_extended_decoration(ops[2], SPIRVCrossDecorationTessIOOriginalInputTypeID))
7313 			expr_type = &get<SPIRType>(get_extended_decoration(ops[2], SPIRVCrossDecorationTessIOOriginalInputTypeID));
7314 		for (uint32_t i = 3; i < length; i++)
7315 		{
7316 			if (!is_array(*expr_type) && expr_type->basetype == SPIRType::Struct)
7317 				expr_type = &get<SPIRType>(expr_type->member_types[get<SPIRConstant>(ops[i]).scalar()]);
7318 			else
7319 				expr_type = &get<SPIRType>(expr_type->parent_type);
7320 		}
7321 		if (!is_array(*expr_type) && !is_matrix(*expr_type) && expr_type->basetype != SPIRType::Struct &&
7322 		    expr_type->vecsize > result_ptr_type.vecsize)
7323 			e += vector_swizzle(result_ptr_type.vecsize, 0);
7324 
7325 		auto &expr = set<SPIRExpression>(ops[1], move(e), ops[0], should_forward(ops[2]));
7326 		expr.loaded_from = var->self;
7327 		expr.need_transpose = meta.need_transpose;
7328 		expr.access_chain = true;
7329 
7330 		// Mark the result as being packed if necessary.
7331 		if (meta.storage_is_packed)
7332 			set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypePacked);
7333 		if (meta.storage_physical_type != 0)
7334 			set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypeID, meta.storage_physical_type);
7335 		if (meta.storage_is_invariant)
7336 			set_decoration(ops[1], DecorationInvariant);
7337 		// Save the type we found in case the result is used in another access chain.
7338 		set_extended_decoration(ops[1], SPIRVCrossDecorationTessIOOriginalInputTypeID, expr_type->self);
7339 
7340 		// If we have some expression dependencies in our access chain, this access chain is technically a forwarded
7341 		// temporary which could be subject to invalidation.
7342 		// Need to assume we're forwarded while calling inherit_expression_dependencies.
7343 		forwarded_temporaries.insert(ops[1]);
7344 		// The access chain itself is never forced to a temporary, but its dependencies might.
7345 		suppressed_usage_tracking.insert(ops[1]);
7346 
7347 		for (uint32_t i = 2; i < length; i++)
7348 		{
7349 			inherit_expression_dependencies(ops[1], ops[i]);
7350 			add_implied_read_expression(expr, ops[i]);
7351 		}
7352 
7353 		// If we have no dependencies after all, i.e., all indices in the access chain are immutable temporaries,
7354 		// we're not forwarded after all.
7355 		if (expr.expression_dependencies.empty())
7356 			forwarded_temporaries.erase(ops[1]);
7357 
7358 		return true;
7359 	}
7360 
7361 	// If this is the inner tessellation level, and we're tessellating triangles,
7362 	// drop the last index. It isn't an array in this case, so we can't have an
7363 	// array reference here. We need to make this ID a variable instead of an
7364 	// expression so we don't try to dereference it as a variable pointer.
7365 	// Don't do this if the index is a constant 1, though. We need to drop stores
7366 	// to that one.
7367 	auto *m = ir.find_meta(var ? var->self : ID(0));
7368 	if (get_execution_model() == ExecutionModelTessellationControl && var && m &&
7369 	    m->decoration.builtin_type == BuiltInTessLevelInner && get_entry_point().flags.get(ExecutionModeTriangles))
7370 	{
7371 		auto *c = maybe_get<SPIRConstant>(ops[3]);
7372 		if (c && c->scalar() == 1)
7373 			return false;
7374 		auto &dest_var = set<SPIRVariable>(ops[1], *var);
7375 		dest_var.basetype = ops[0];
7376 		ir.meta[ops[1]] = ir.meta[ops[2]];
7377 		inherit_expression_dependencies(ops[1], ops[2]);
7378 		return true;
7379 	}
7380 
7381 	return false;
7382 }
7383 
is_out_of_bounds_tessellation_level(uint32_t id_lhs)7384 bool CompilerMSL::is_out_of_bounds_tessellation_level(uint32_t id_lhs)
7385 {
7386 	if (!get_entry_point().flags.get(ExecutionModeTriangles))
7387 		return false;
7388 
7389 	// In SPIR-V, TessLevelInner always has two elements and TessLevelOuter always has
7390 	// four. This is true even if we are tessellating triangles. This allows clients
7391 	// to use a single tessellation control shader with multiple tessellation evaluation
7392 	// shaders.
7393 	// In Metal, however, only the first element of TessLevelInner and the first three
7394 	// of TessLevelOuter are accessible. This stems from how in Metal, the tessellation
7395 	// levels must be stored to a dedicated buffer in a particular format that depends
7396 	// on the patch type. Therefore, in Triangles mode, any access to the second
7397 	// inner level or the fourth outer level must be dropped.
7398 	const auto *e = maybe_get<SPIRExpression>(id_lhs);
7399 	if (!e || !e->access_chain)
7400 		return false;
7401 	BuiltIn builtin = BuiltIn(get_decoration(e->loaded_from, DecorationBuiltIn));
7402 	if (builtin != BuiltInTessLevelInner && builtin != BuiltInTessLevelOuter)
7403 		return false;
7404 	auto *c = maybe_get<SPIRConstant>(e->implied_read_expressions[1]);
7405 	if (!c)
7406 		return false;
7407 	return (builtin == BuiltInTessLevelInner && c->scalar() == 1) ||
7408 	       (builtin == BuiltInTessLevelOuter && c->scalar() == 3);
7409 }
7410 
prepare_access_chain_for_scalar_access(std::string & expr,const SPIRType & type,spv::StorageClass storage,bool & is_packed)7411 void CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
7412                                                          spv::StorageClass storage, bool &is_packed)
7413 {
7414 	// If there is any risk of writes happening with the access chain in question,
7415 	// and there is a risk of concurrent write access to other components,
7416 	// we must cast the access chain to a plain pointer to ensure we only access the exact scalars we expect.
7417 	// The MSL compiler refuses to allow component-level access for any non-packed vector types.
7418 	if (!is_packed && (storage == StorageClassStorageBuffer || storage == StorageClassWorkgroup))
7419 	{
7420 		const char *addr_space = storage == StorageClassWorkgroup ? "threadgroup" : "device";
7421 		expr = join("((", addr_space, " ", type_to_glsl(type), "*)&", enclose_expression(expr), ")");
7422 
7423 		// Further indexing should happen with packed rules (array index, not swizzle).
7424 		is_packed = true;
7425 	}
7426 }
7427 
access_chain_needs_stage_io_builtin_translation(uint32_t base)7428 bool CompilerMSL::access_chain_needs_stage_io_builtin_translation(uint32_t base)
7429 {
7430 	auto *var = maybe_get_backing_variable(base);
7431 	if (!var || !is_tessellation_shader())
7432 		return true;
7433 
7434 	// We only need to rewrite builtin access chains when accessing flattened builtins like gl_ClipDistance_N.
7435 	// Avoid overriding it back to just gl_ClipDistance.
7436 	// This can only happen in scenarios where we cannot flatten/unflatten access chains, so, the only case
7437 	// where this triggers is evaluation shader inputs.
7438 	bool redirect_builtin = get_execution_model() == ExecutionModelTessellationEvaluation ?
7439 	                        var->storage == StorageClassOutput : false;
7440 	return redirect_builtin;
7441 }
7442 
7443 // Sets the interface member index for an access chain to a pull-model interpolant.
fix_up_interpolant_access_chain(const uint32_t * ops,uint32_t length)7444 void CompilerMSL::fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t length)
7445 {
7446 	auto *var = maybe_get_backing_variable(ops[2]);
7447 	if (!var || !pull_model_inputs.count(var->self))
7448 		return;
7449 	// Get the base index.
7450 	uint32_t interface_index;
7451 	auto &var_type = get_variable_data_type(*var);
7452 	auto &result_type = get<SPIRType>(ops[0]);
7453 	auto *type = &var_type;
7454 	if (has_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex))
7455 	{
7456 		interface_index = get_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex);
7457 	}
7458 	else
7459 	{
7460 		// Assume an access chain into a struct variable.
7461 		assert(var_type.basetype == SPIRType::Struct);
7462 		auto &c = get<SPIRConstant>(ops[3 + var_type.array.size()]);
7463 		interface_index =
7464 		    get_extended_member_decoration(var->self, c.scalar(), SPIRVCrossDecorationInterfaceMemberIndex);
7465 	}
7466 	// Accumulate indices. We'll have to skip over the one for the struct, if present, because we already accounted
7467 	// for that getting the base index.
7468 	for (uint32_t i = 3; i < length; ++i)
7469 	{
7470 		if (is_vector(*type) && !is_array(*type) && is_scalar(result_type))
7471 		{
7472 			// We don't want to combine the next index. Actually, we need to save it
7473 			// so we know to apply a swizzle to the result of the interpolation.
7474 			set_extended_decoration(ops[1], SPIRVCrossDecorationInterpolantComponentExpr, ops[i]);
7475 			break;
7476 		}
7477 
7478 		auto *c = maybe_get<SPIRConstant>(ops[i]);
7479 		if (!c || c->specialization)
7480 			SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable using pull-model "
7481 			                  "interpolation. This is currently unsupported.");
7482 
7483 		if (type->parent_type)
7484 			type = &get<SPIRType>(type->parent_type);
7485 		else if (type->basetype == SPIRType::Struct)
7486 			type = &get<SPIRType>(type->member_types[c->scalar()]);
7487 
7488 		if (!has_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex) &&
7489 		    i - 3 == var_type.array.size())
7490 			continue;
7491 
7492 		interface_index += c->scalar();
7493 	}
7494 	// Save this to the access chain itself so we can recover it later when calling an interpolation function.
7495 	set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, interface_index);
7496 }
7497 
7498 // Override for MSL-specific syntax instructions
emit_instruction(const Instruction & instruction)7499 void CompilerMSL::emit_instruction(const Instruction &instruction)
7500 {
7501 #define MSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
7502 #define MSL_BOP_CAST(op, type) \
7503 	emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
7504 #define MSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
7505 #define MSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
7506 #define MSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
7507 #define MSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
7508 #define MSL_BFOP_CAST(op, type) \
7509 	emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
7510 #define MSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
7511 #define MSL_UNORD_BOP(op) emit_binary_unord_op(ops[0], ops[1], ops[2], ops[3], #op)
7512 
7513 	auto ops = stream(instruction);
7514 	auto opcode = static_cast<Op>(instruction.op);
7515 
7516 	// If we need to do implicit bitcasts, make sure we do it with the correct type.
7517 	uint32_t integer_width = get_integer_width_for_instruction(instruction);
7518 	auto int_type = to_signed_basetype(integer_width);
7519 	auto uint_type = to_unsigned_basetype(integer_width);
7520 
7521 	switch (opcode)
7522 	{
7523 	case OpLoad:
7524 	{
7525 		uint32_t id = ops[1];
7526 		uint32_t ptr = ops[2];
7527 		if (is_tessellation_shader())
7528 		{
7529 			if (!emit_tessellation_io_load(ops[0], id, ptr))
7530 				CompilerGLSL::emit_instruction(instruction);
7531 		}
7532 		else
7533 		{
7534 			// Sample mask input for Metal is not an array
7535 			if (BuiltIn(get_decoration(ptr, DecorationBuiltIn)) == BuiltInSampleMask)
7536 				set_decoration(id, DecorationBuiltIn, BuiltInSampleMask);
7537 			CompilerGLSL::emit_instruction(instruction);
7538 		}
7539 		break;
7540 	}
7541 
7542 	// Comparisons
7543 	case OpIEqual:
7544 		MSL_BOP_CAST(==, int_type);
7545 		break;
7546 
7547 	case OpLogicalEqual:
7548 	case OpFOrdEqual:
7549 		MSL_BOP(==);
7550 		break;
7551 
7552 	case OpINotEqual:
7553 		MSL_BOP_CAST(!=, int_type);
7554 		break;
7555 
7556 	case OpLogicalNotEqual:
7557 	case OpFOrdNotEqual:
7558 		MSL_BOP(!=);
7559 		break;
7560 
7561 	case OpUGreaterThan:
7562 		MSL_BOP_CAST(>, uint_type);
7563 		break;
7564 
7565 	case OpSGreaterThan:
7566 		MSL_BOP_CAST(>, int_type);
7567 		break;
7568 
7569 	case OpFOrdGreaterThan:
7570 		MSL_BOP(>);
7571 		break;
7572 
7573 	case OpUGreaterThanEqual:
7574 		MSL_BOP_CAST(>=, uint_type);
7575 		break;
7576 
7577 	case OpSGreaterThanEqual:
7578 		MSL_BOP_CAST(>=, int_type);
7579 		break;
7580 
7581 	case OpFOrdGreaterThanEqual:
7582 		MSL_BOP(>=);
7583 		break;
7584 
7585 	case OpULessThan:
7586 		MSL_BOP_CAST(<, uint_type);
7587 		break;
7588 
7589 	case OpSLessThan:
7590 		MSL_BOP_CAST(<, int_type);
7591 		break;
7592 
7593 	case OpFOrdLessThan:
7594 		MSL_BOP(<);
7595 		break;
7596 
7597 	case OpULessThanEqual:
7598 		MSL_BOP_CAST(<=, uint_type);
7599 		break;
7600 
7601 	case OpSLessThanEqual:
7602 		MSL_BOP_CAST(<=, int_type);
7603 		break;
7604 
7605 	case OpFOrdLessThanEqual:
7606 		MSL_BOP(<=);
7607 		break;
7608 
7609 	case OpFUnordEqual:
7610 		MSL_UNORD_BOP(==);
7611 		break;
7612 
7613 	case OpFUnordNotEqual:
7614 		MSL_UNORD_BOP(!=);
7615 		break;
7616 
7617 	case OpFUnordGreaterThan:
7618 		MSL_UNORD_BOP(>);
7619 		break;
7620 
7621 	case OpFUnordGreaterThanEqual:
7622 		MSL_UNORD_BOP(>=);
7623 		break;
7624 
7625 	case OpFUnordLessThan:
7626 		MSL_UNORD_BOP(<);
7627 		break;
7628 
7629 	case OpFUnordLessThanEqual:
7630 		MSL_UNORD_BOP(<=);
7631 		break;
7632 
7633 	// Derivatives
7634 	case OpDPdx:
7635 	case OpDPdxFine:
7636 	case OpDPdxCoarse:
7637 		MSL_UFOP(dfdx);
7638 		register_control_dependent_expression(ops[1]);
7639 		break;
7640 
7641 	case OpDPdy:
7642 	case OpDPdyFine:
7643 	case OpDPdyCoarse:
7644 		MSL_UFOP(dfdy);
7645 		register_control_dependent_expression(ops[1]);
7646 		break;
7647 
7648 	case OpFwidth:
7649 	case OpFwidthCoarse:
7650 	case OpFwidthFine:
7651 		MSL_UFOP(fwidth);
7652 		register_control_dependent_expression(ops[1]);
7653 		break;
7654 
7655 	// Bitfield
7656 	case OpBitFieldInsert:
7657 	{
7658 		emit_bitfield_insert_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], "insert_bits", SPIRType::UInt);
7659 		break;
7660 	}
7661 
7662 	case OpBitFieldSExtract:
7663 	{
7664 		emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "extract_bits", int_type, int_type,
7665 		                                SPIRType::UInt, SPIRType::UInt);
7666 		break;
7667 	}
7668 
7669 	case OpBitFieldUExtract:
7670 	{
7671 		emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "extract_bits", uint_type, uint_type,
7672 		                                SPIRType::UInt, SPIRType::UInt);
7673 		break;
7674 	}
7675 
7676 	case OpBitReverse:
7677 		// BitReverse does not have issues with sign since result type must match input type.
7678 		MSL_UFOP(reverse_bits);
7679 		break;
7680 
7681 	case OpBitCount:
7682 	{
7683 		auto basetype = expression_type(ops[2]).basetype;
7684 		emit_unary_func_op_cast(ops[0], ops[1], ops[2], "popcount", basetype, basetype);
7685 		break;
7686 	}
7687 
7688 	case OpFRem:
7689 		MSL_BFOP(fmod);
7690 		break;
7691 
7692 	case OpFMul:
7693 		if (msl_options.invariant_float_math || has_decoration(ops[1], DecorationNoContraction))
7694 			MSL_BFOP(spvFMul);
7695 		else
7696 			MSL_BOP(*);
7697 		break;
7698 
7699 	case OpFAdd:
7700 		if (msl_options.invariant_float_math || has_decoration(ops[1], DecorationNoContraction))
7701 			MSL_BFOP(spvFAdd);
7702 		else
7703 			MSL_BOP(+);
7704 		break;
7705 
7706 	case OpFSub:
7707 		if (msl_options.invariant_float_math || has_decoration(ops[1], DecorationNoContraction))
7708 			MSL_BFOP(spvFSub);
7709 		else
7710 			MSL_BOP(-);
7711 		break;
7712 
7713 	// Atomics
7714 	case OpAtomicExchange:
7715 	{
7716 		uint32_t result_type = ops[0];
7717 		uint32_t id = ops[1];
7718 		uint32_t ptr = ops[2];
7719 		uint32_t mem_sem = ops[4];
7720 		uint32_t val = ops[5];
7721 		emit_atomic_func_op(result_type, id, "atomic_exchange_explicit", mem_sem, mem_sem, false, ptr, val);
7722 		break;
7723 	}
7724 
7725 	case OpAtomicCompareExchange:
7726 	{
7727 		uint32_t result_type = ops[0];
7728 		uint32_t id = ops[1];
7729 		uint32_t ptr = ops[2];
7730 		uint32_t mem_sem_pass = ops[4];
7731 		uint32_t mem_sem_fail = ops[5];
7732 		uint32_t val = ops[6];
7733 		uint32_t comp = ops[7];
7734 		emit_atomic_func_op(result_type, id, "atomic_compare_exchange_weak_explicit", mem_sem_pass, mem_sem_fail, true,
7735 		                    ptr, comp, true, false, val);
7736 		break;
7737 	}
7738 
7739 	case OpAtomicCompareExchangeWeak:
7740 		SPIRV_CROSS_THROW("OpAtomicCompareExchangeWeak is only supported in kernel profile.");
7741 
7742 	case OpAtomicLoad:
7743 	{
7744 		uint32_t result_type = ops[0];
7745 		uint32_t id = ops[1];
7746 		uint32_t ptr = ops[2];
7747 		uint32_t mem_sem = ops[4];
7748 		emit_atomic_func_op(result_type, id, "atomic_load_explicit", mem_sem, mem_sem, false, ptr, 0);
7749 		break;
7750 	}
7751 
7752 	case OpAtomicStore:
7753 	{
7754 		uint32_t result_type = expression_type(ops[0]).self;
7755 		uint32_t id = ops[0];
7756 		uint32_t ptr = ops[0];
7757 		uint32_t mem_sem = ops[2];
7758 		uint32_t val = ops[3];
7759 		emit_atomic_func_op(result_type, id, "atomic_store_explicit", mem_sem, mem_sem, false, ptr, val);
7760 		break;
7761 	}
7762 
7763 #define MSL_AFMO_IMPL(op, valsrc, valconst)                                                                      \
7764 	do                                                                                                           \
7765 	{                                                                                                            \
7766 		uint32_t result_type = ops[0];                                                                           \
7767 		uint32_t id = ops[1];                                                                                    \
7768 		uint32_t ptr = ops[2];                                                                                   \
7769 		uint32_t mem_sem = ops[4];                                                                               \
7770 		uint32_t val = valsrc;                                                                                   \
7771 		emit_atomic_func_op(result_type, id, "atomic_fetch_" #op "_explicit", mem_sem, mem_sem, false, ptr, val, \
7772 		                    false, valconst);                                                                    \
7773 	} while (false)
7774 
7775 #define MSL_AFMO(op) MSL_AFMO_IMPL(op, ops[5], false)
7776 #define MSL_AFMIO(op) MSL_AFMO_IMPL(op, 1, true)
7777 
7778 	case OpAtomicIIncrement:
7779 		MSL_AFMIO(add);
7780 		break;
7781 
7782 	case OpAtomicIDecrement:
7783 		MSL_AFMIO(sub);
7784 		break;
7785 
7786 	case OpAtomicIAdd:
7787 		MSL_AFMO(add);
7788 		break;
7789 
7790 	case OpAtomicISub:
7791 		MSL_AFMO(sub);
7792 		break;
7793 
7794 	case OpAtomicSMin:
7795 	case OpAtomicUMin:
7796 		MSL_AFMO(min);
7797 		break;
7798 
7799 	case OpAtomicSMax:
7800 	case OpAtomicUMax:
7801 		MSL_AFMO(max);
7802 		break;
7803 
7804 	case OpAtomicAnd:
7805 		MSL_AFMO(and);
7806 		break;
7807 
7808 	case OpAtomicOr:
7809 		MSL_AFMO(or);
7810 		break;
7811 
7812 	case OpAtomicXor:
7813 		MSL_AFMO(xor);
7814 		break;
7815 
7816 	// Images
7817 
7818 	// Reads == Fetches in Metal
7819 	case OpImageRead:
7820 	{
7821 		// Mark that this shader reads from this image
7822 		uint32_t img_id = ops[2];
7823 		auto &type = expression_type(img_id);
7824 		if (type.image.dim != DimSubpassData)
7825 		{
7826 			auto *p_var = maybe_get_backing_variable(img_id);
7827 			if (p_var && has_decoration(p_var->self, DecorationNonReadable))
7828 			{
7829 				unset_decoration(p_var->self, DecorationNonReadable);
7830 				force_recompile();
7831 			}
7832 		}
7833 
7834 		emit_texture_op(instruction, false);
7835 		break;
7836 	}
7837 
7838 	// Emulate texture2D atomic operations
7839 	case OpImageTexelPointer:
7840 	{
7841 		// When using the pointer, we need to know which variable it is actually loaded from.
7842 		auto *var = maybe_get_backing_variable(ops[2]);
7843 		if (var && atomic_image_vars.count(var->self))
7844 		{
7845 			uint32_t result_type = ops[0];
7846 			uint32_t id = ops[1];
7847 
7848 			std::string coord = to_expression(ops[3]);
7849 			auto &type = expression_type(ops[2]);
7850 			if (type.image.dim == Dim2D)
7851 			{
7852 				coord = join("spvImage2DAtomicCoord(", coord, ", ", to_expression(ops[2]), ")");
7853 			}
7854 
7855 			auto &e = set<SPIRExpression>(id, join(to_expression(ops[2]), "_atomic[", coord, "]"), result_type, true);
7856 			e.loaded_from = var ? var->self : ID(0);
7857 			inherit_expression_dependencies(id, ops[3]);
7858 		}
7859 		else
7860 		{
7861 			uint32_t result_type = ops[0];
7862 			uint32_t id = ops[1];
7863 			auto &e =
7864 			    set<SPIRExpression>(id, join(to_expression(ops[2]), ", ", to_expression(ops[3])), result_type, true);
7865 
7866 			// When using the pointer, we need to know which variable it is actually loaded from.
7867 			e.loaded_from = var ? var->self : ID(0);
7868 			inherit_expression_dependencies(id, ops[3]);
7869 		}
7870 		break;
7871 	}
7872 
7873 	case OpImageWrite:
7874 	{
7875 		uint32_t img_id = ops[0];
7876 		uint32_t coord_id = ops[1];
7877 		uint32_t texel_id = ops[2];
7878 		const uint32_t *opt = &ops[3];
7879 		uint32_t length = instruction.length - 3;
7880 
7881 		// Bypass pointers because we need the real image struct
7882 		auto &type = expression_type(img_id);
7883 		auto &img_type = get<SPIRType>(type.self);
7884 
7885 		// Ensure this image has been marked as being written to and force a
7886 		// recompile so that the image type output will include write access
7887 		auto *p_var = maybe_get_backing_variable(img_id);
7888 		if (p_var && has_decoration(p_var->self, DecorationNonWritable))
7889 		{
7890 			unset_decoration(p_var->self, DecorationNonWritable);
7891 			force_recompile();
7892 		}
7893 
7894 		bool forward = false;
7895 		uint32_t bias = 0;
7896 		uint32_t lod = 0;
7897 		uint32_t flags = 0;
7898 
7899 		if (length)
7900 		{
7901 			flags = *opt++;
7902 			length--;
7903 		}
7904 
7905 		auto test = [&](uint32_t &v, uint32_t flag) {
7906 			if (length && (flags & flag))
7907 			{
7908 				v = *opt++;
7909 				length--;
7910 			}
7911 		};
7912 
7913 		test(bias, ImageOperandsBiasMask);
7914 		test(lod, ImageOperandsLodMask);
7915 
7916 		auto &texel_type = expression_type(texel_id);
7917 		auto store_type = texel_type;
7918 		store_type.vecsize = 4;
7919 
7920 		TextureFunctionArguments args = {};
7921 		args.base.img = img_id;
7922 		args.base.imgtype = &img_type;
7923 		args.base.is_fetch = true;
7924 		args.coord = coord_id;
7925 		args.lod = lod;
7926 		statement(join(to_expression(img_id), ".write(",
7927 		               remap_swizzle(store_type, texel_type.vecsize, to_expression(texel_id)), ", ",
7928 		               CompilerMSL::to_function_args(args, &forward), ");"));
7929 
7930 		if (p_var && variable_storage_is_aliased(*p_var))
7931 			flush_all_aliased_variables();
7932 
7933 		break;
7934 	}
7935 
7936 	case OpImageQuerySize:
7937 	case OpImageQuerySizeLod:
7938 	{
7939 		uint32_t rslt_type_id = ops[0];
7940 		auto &rslt_type = get<SPIRType>(rslt_type_id);
7941 
7942 		uint32_t id = ops[1];
7943 
7944 		uint32_t img_id = ops[2];
7945 		string img_exp = to_expression(img_id);
7946 		auto &img_type = expression_type(img_id);
7947 		Dim img_dim = img_type.image.dim;
7948 		bool img_is_array = img_type.image.arrayed;
7949 
7950 		if (img_type.basetype != SPIRType::Image)
7951 			SPIRV_CROSS_THROW("Invalid type for OpImageQuerySize.");
7952 
7953 		string lod;
7954 		if (opcode == OpImageQuerySizeLod)
7955 		{
7956 			// LOD index defaults to zero, so don't bother outputting level zero index
7957 			string decl_lod = to_expression(ops[3]);
7958 			if (decl_lod != "0")
7959 				lod = decl_lod;
7960 		}
7961 
7962 		string expr = type_to_glsl(rslt_type) + "(";
7963 		expr += img_exp + ".get_width(" + lod + ")";
7964 
7965 		if (img_dim == Dim2D || img_dim == DimCube || img_dim == Dim3D)
7966 			expr += ", " + img_exp + ".get_height(" + lod + ")";
7967 
7968 		if (img_dim == Dim3D)
7969 			expr += ", " + img_exp + ".get_depth(" + lod + ")";
7970 
7971 		if (img_is_array)
7972 		{
7973 			expr += ", " + img_exp + ".get_array_size()";
7974 			if (img_dim == DimCube && msl_options.emulate_cube_array)
7975 				expr += " / 6";
7976 		}
7977 
7978 		expr += ")";
7979 
7980 		emit_op(rslt_type_id, id, expr, should_forward(img_id));
7981 
7982 		break;
7983 	}
7984 
7985 	case OpImageQueryLod:
7986 	{
7987 		if (!msl_options.supports_msl_version(2, 2))
7988 			SPIRV_CROSS_THROW("ImageQueryLod is only supported on MSL 2.2 and up.");
7989 		uint32_t result_type = ops[0];
7990 		uint32_t id = ops[1];
7991 		uint32_t image_id = ops[2];
7992 		uint32_t coord_id = ops[3];
7993 		emit_uninitialized_temporary_expression(result_type, id);
7994 
7995 		auto sampler_expr = to_sampler_expression(image_id);
7996 		auto *combined = maybe_get<SPIRCombinedImageSampler>(image_id);
7997 		auto image_expr = combined ? to_expression(combined->image) : to_expression(image_id);
7998 
7999 		// TODO: It is unclear if calculate_clamped_lod also conditionally rounds
8000 		// the reported LOD based on the sampler. NEAREST miplevel should
8001 		// round the LOD, but LINEAR miplevel should not round.
8002 		// Let's hope this does not become an issue ...
8003 		statement(to_expression(id), ".x = ", image_expr, ".calculate_clamped_lod(", sampler_expr, ", ",
8004 		          to_expression(coord_id), ");");
8005 		statement(to_expression(id), ".y = ", image_expr, ".calculate_unclamped_lod(", sampler_expr, ", ",
8006 		          to_expression(coord_id), ");");
8007 		register_control_dependent_expression(id);
8008 		break;
8009 	}
8010 
8011 #define MSL_ImgQry(qrytype)                                                                 \
8012 	do                                                                                      \
8013 	{                                                                                       \
8014 		uint32_t rslt_type_id = ops[0];                                                     \
8015 		auto &rslt_type = get<SPIRType>(rslt_type_id);                                      \
8016 		uint32_t id = ops[1];                                                               \
8017 		uint32_t img_id = ops[2];                                                           \
8018 		string img_exp = to_expression(img_id);                                             \
8019 		string expr = type_to_glsl(rslt_type) + "(" + img_exp + ".get_num_" #qrytype "())"; \
8020 		emit_op(rslt_type_id, id, expr, should_forward(img_id));                            \
8021 	} while (false)
8022 
8023 	case OpImageQueryLevels:
8024 		MSL_ImgQry(mip_levels);
8025 		break;
8026 
8027 	case OpImageQuerySamples:
8028 		MSL_ImgQry(samples);
8029 		break;
8030 
8031 	case OpImage:
8032 	{
8033 		uint32_t result_type = ops[0];
8034 		uint32_t id = ops[1];
8035 		auto *combined = maybe_get<SPIRCombinedImageSampler>(ops[2]);
8036 
8037 		if (combined)
8038 		{
8039 			auto &e = emit_op(result_type, id, to_expression(combined->image), true, true);
8040 			auto *var = maybe_get_backing_variable(combined->image);
8041 			if (var)
8042 				e.loaded_from = var->self;
8043 		}
8044 		else
8045 		{
8046 			auto *var = maybe_get_backing_variable(ops[2]);
8047 			SPIRExpression *e;
8048 			if (var && has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler))
8049 				e = &emit_op(result_type, id, join(to_expression(ops[2]), ".plane0"), true, true);
8050 			else
8051 				e = &emit_op(result_type, id, to_expression(ops[2]), true, true);
8052 			if (var)
8053 				e->loaded_from = var->self;
8054 		}
8055 		break;
8056 	}
8057 
8058 	// Casting
8059 	case OpQuantizeToF16:
8060 	{
8061 		uint32_t result_type = ops[0];
8062 		uint32_t id = ops[1];
8063 		uint32_t arg = ops[2];
8064 
8065 		string exp;
8066 		auto &type = get<SPIRType>(result_type);
8067 
8068 		switch (type.vecsize)
8069 		{
8070 		case 1:
8071 			exp = join("float(half(", to_expression(arg), "))");
8072 			break;
8073 		case 2:
8074 			exp = join("float2(half2(", to_expression(arg), "))");
8075 			break;
8076 		case 3:
8077 			exp = join("float3(half3(", to_expression(arg), "))");
8078 			break;
8079 		case 4:
8080 			exp = join("float4(half4(", to_expression(arg), "))");
8081 			break;
8082 		default:
8083 			SPIRV_CROSS_THROW("Illegal argument to OpQuantizeToF16.");
8084 		}
8085 
8086 		emit_op(result_type, id, exp, should_forward(arg));
8087 		break;
8088 	}
8089 
8090 	case OpInBoundsAccessChain:
8091 	case OpAccessChain:
8092 	case OpPtrAccessChain:
8093 		if (is_tessellation_shader())
8094 		{
8095 			if (!emit_tessellation_access_chain(ops, instruction.length))
8096 				CompilerGLSL::emit_instruction(instruction);
8097 		}
8098 		else
8099 			CompilerGLSL::emit_instruction(instruction);
8100 		fix_up_interpolant_access_chain(ops, instruction.length);
8101 		break;
8102 
8103 	case OpStore:
8104 		if (is_out_of_bounds_tessellation_level(ops[0]))
8105 			break;
8106 
8107 		if (maybe_emit_array_assignment(ops[0], ops[1]))
8108 			break;
8109 
8110 		CompilerGLSL::emit_instruction(instruction);
8111 		break;
8112 
8113 	// Compute barriers
8114 	case OpMemoryBarrier:
8115 		emit_barrier(0, ops[0], ops[1]);
8116 		break;
8117 
8118 	case OpControlBarrier:
8119 		// In GLSL a memory barrier is often followed by a control barrier.
8120 		// But in MSL, memory barriers are also control barriers, so don't
8121 		// emit a simple control barrier if a memory barrier has just been emitted.
8122 		if (previous_instruction_opcode != OpMemoryBarrier)
8123 			emit_barrier(ops[0], ops[1], ops[2]);
8124 		break;
8125 
8126 	case OpOuterProduct:
8127 	{
8128 		uint32_t result_type = ops[0];
8129 		uint32_t id = ops[1];
8130 		uint32_t a = ops[2];
8131 		uint32_t b = ops[3];
8132 
8133 		auto &type = get<SPIRType>(result_type);
8134 		string expr = type_to_glsl_constructor(type);
8135 		expr += "(";
8136 		for (uint32_t col = 0; col < type.columns; col++)
8137 		{
8138 			expr += to_enclosed_expression(a);
8139 			expr += " * ";
8140 			expr += to_extract_component_expression(b, col);
8141 			if (col + 1 < type.columns)
8142 				expr += ", ";
8143 		}
8144 		expr += ")";
8145 		emit_op(result_type, id, expr, should_forward(a) && should_forward(b));
8146 		inherit_expression_dependencies(id, a);
8147 		inherit_expression_dependencies(id, b);
8148 		break;
8149 	}
8150 
8151 	case OpVectorTimesMatrix:
8152 	case OpMatrixTimesVector:
8153 	{
8154 		if (!msl_options.invariant_float_math && !has_decoration(ops[1], DecorationNoContraction))
8155 		{
8156 			CompilerGLSL::emit_instruction(instruction);
8157 			break;
8158 		}
8159 
8160 		// If the matrix needs transpose, just flip the multiply order.
8161 		auto *e = maybe_get<SPIRExpression>(ops[opcode == OpMatrixTimesVector ? 2 : 3]);
8162 		if (e && e->need_transpose)
8163 		{
8164 			e->need_transpose = false;
8165 			string expr;
8166 
8167 			if (opcode == OpMatrixTimesVector)
8168 			{
8169 				expr = join("spvFMulVectorMatrix(", to_enclosed_unpacked_expression(ops[3]), ", ",
8170 				            to_unpacked_row_major_matrix_expression(ops[2]), ")");
8171 			}
8172 			else
8173 			{
8174 				expr = join("spvFMulMatrixVector(", to_unpacked_row_major_matrix_expression(ops[3]), ", ",
8175 				            to_enclosed_unpacked_expression(ops[2]), ")");
8176 			}
8177 
8178 			bool forward = should_forward(ops[2]) && should_forward(ops[3]);
8179 			emit_op(ops[0], ops[1], expr, forward);
8180 			e->need_transpose = true;
8181 			inherit_expression_dependencies(ops[1], ops[2]);
8182 			inherit_expression_dependencies(ops[1], ops[3]);
8183 		}
8184 		else
8185 		{
8186 			if (opcode == OpMatrixTimesVector)
8187 				MSL_BFOP(spvFMulMatrixVector);
8188 			else
8189 				MSL_BFOP(spvFMulVectorMatrix);
8190 		}
8191 		break;
8192 	}
8193 
8194 	case OpMatrixTimesMatrix:
8195 	{
8196 		if (!msl_options.invariant_float_math && !has_decoration(ops[1], DecorationNoContraction))
8197 		{
8198 			CompilerGLSL::emit_instruction(instruction);
8199 			break;
8200 		}
8201 
8202 		auto *a = maybe_get<SPIRExpression>(ops[2]);
8203 		auto *b = maybe_get<SPIRExpression>(ops[3]);
8204 
8205 		// If both matrices need transpose, we can multiply in flipped order and tag the expression as transposed.
8206 		// a^T * b^T = (b * a)^T.
8207 		if (a && b && a->need_transpose && b->need_transpose)
8208 		{
8209 			a->need_transpose = false;
8210 			b->need_transpose = false;
8211 
8212 			auto expr =
8213 			    join("spvFMulMatrixMatrix(", enclose_expression(to_unpacked_row_major_matrix_expression(ops[3])), ", ",
8214 			         enclose_expression(to_unpacked_row_major_matrix_expression(ops[2])), ")");
8215 
8216 			bool forward = should_forward(ops[2]) && should_forward(ops[3]);
8217 			auto &e = emit_op(ops[0], ops[1], expr, forward);
8218 			e.need_transpose = true;
8219 			a->need_transpose = true;
8220 			b->need_transpose = true;
8221 			inherit_expression_dependencies(ops[1], ops[2]);
8222 			inherit_expression_dependencies(ops[1], ops[3]);
8223 		}
8224 		else
8225 			MSL_BFOP(spvFMulMatrixMatrix);
8226 
8227 		break;
8228 	}
8229 
8230 	case OpIAddCarry:
8231 	case OpISubBorrow:
8232 	{
8233 		uint32_t result_type = ops[0];
8234 		uint32_t result_id = ops[1];
8235 		uint32_t op0 = ops[2];
8236 		uint32_t op1 = ops[3];
8237 		auto &type = get<SPIRType>(result_type);
8238 		emit_uninitialized_temporary_expression(result_type, result_id);
8239 
8240 		auto &res_type = get<SPIRType>(type.member_types[1]);
8241 		if (opcode == OpIAddCarry)
8242 		{
8243 			statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " + ",
8244 			          to_enclosed_expression(op1), ";");
8245 			statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
8246 			          "(1), ", type_to_glsl(res_type), "(0), ", to_expression(result_id), ".", to_member_name(type, 0),
8247 			          " >= max(", to_expression(op0), ", ", to_expression(op1), "));");
8248 		}
8249 		else
8250 		{
8251 			statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " - ",
8252 			          to_enclosed_expression(op1), ";");
8253 			statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
8254 			          "(1), ", type_to_glsl(res_type), "(0), ", to_enclosed_expression(op0),
8255 			          " >= ", to_enclosed_expression(op1), ");");
8256 		}
8257 		break;
8258 	}
8259 
8260 	case OpUMulExtended:
8261 	case OpSMulExtended:
8262 	{
8263 		uint32_t result_type = ops[0];
8264 		uint32_t result_id = ops[1];
8265 		uint32_t op0 = ops[2];
8266 		uint32_t op1 = ops[3];
8267 		auto &type = get<SPIRType>(result_type);
8268 		emit_uninitialized_temporary_expression(result_type, result_id);
8269 
8270 		statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " * ",
8271 		          to_enclosed_expression(op1), ";");
8272 		statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(", to_expression(op0), ", ",
8273 		          to_expression(op1), ");");
8274 		break;
8275 	}
8276 
8277 	case OpArrayLength:
8278 	{
8279 		auto &type = expression_type(ops[2]);
8280 		uint32_t offset = type_struct_member_offset(type, ops[3]);
8281 		uint32_t stride = type_struct_member_array_stride(type, ops[3]);
8282 
8283 		auto expr = join("(", to_buffer_size_expression(ops[2]), " - ", offset, ") / ", stride);
8284 		emit_op(ops[0], ops[1], expr, true);
8285 		break;
8286 	}
8287 
8288 	// SPV_INTEL_shader_integer_functions2
8289 	case OpUCountLeadingZerosINTEL:
8290 		MSL_UFOP(clz);
8291 		break;
8292 
8293 	case OpUCountTrailingZerosINTEL:
8294 		MSL_UFOP(ctz);
8295 		break;
8296 
8297 	case OpAbsISubINTEL:
8298 	case OpAbsUSubINTEL:
8299 		MSL_BFOP(absdiff);
8300 		break;
8301 
8302 	case OpIAddSatINTEL:
8303 	case OpUAddSatINTEL:
8304 		MSL_BFOP(addsat);
8305 		break;
8306 
8307 	case OpIAverageINTEL:
8308 	case OpUAverageINTEL:
8309 		MSL_BFOP(hadd);
8310 		break;
8311 
8312 	case OpIAverageRoundedINTEL:
8313 	case OpUAverageRoundedINTEL:
8314 		MSL_BFOP(rhadd);
8315 		break;
8316 
8317 	case OpISubSatINTEL:
8318 	case OpUSubSatINTEL:
8319 		MSL_BFOP(subsat);
8320 		break;
8321 
8322 	case OpIMul32x16INTEL:
8323 	{
8324 		uint32_t result_type = ops[0];
8325 		uint32_t id = ops[1];
8326 		uint32_t a = ops[2], b = ops[3];
8327 		bool forward = should_forward(a) && should_forward(b);
8328 		emit_op(result_type, id, join("int(short(", to_expression(a), ")) * int(short(", to_expression(b), "))"),
8329 		        forward);
8330 		inherit_expression_dependencies(id, a);
8331 		inherit_expression_dependencies(id, b);
8332 		break;
8333 	}
8334 
8335 	case OpUMul32x16INTEL:
8336 	{
8337 		uint32_t result_type = ops[0];
8338 		uint32_t id = ops[1];
8339 		uint32_t a = ops[2], b = ops[3];
8340 		bool forward = should_forward(a) && should_forward(b);
8341 		emit_op(result_type, id, join("uint(ushort(", to_expression(a), ")) * uint(ushort(", to_expression(b), "))"),
8342 		        forward);
8343 		inherit_expression_dependencies(id, a);
8344 		inherit_expression_dependencies(id, b);
8345 		break;
8346 	}
8347 
8348 	// SPV_EXT_demote_to_helper_invocation
8349 	case OpDemoteToHelperInvocationEXT:
8350 		if (!msl_options.supports_msl_version(2, 3))
8351 			SPIRV_CROSS_THROW("discard_fragment() does not formally have demote semantics until MSL 2.3.");
8352 		CompilerGLSL::emit_instruction(instruction);
8353 		break;
8354 
8355 	case OpIsHelperInvocationEXT:
8356 		if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
8357 			SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.3 on iOS.");
8358 		else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
8359 			SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.1 on macOS.");
8360 		emit_op(ops[0], ops[1], "simd_is_helper_thread()", false);
8361 		break;
8362 
8363 	case OpBeginInvocationInterlockEXT:
8364 	case OpEndInvocationInterlockEXT:
8365 		if (!msl_options.supports_msl_version(2, 0))
8366 			SPIRV_CROSS_THROW("Raster order groups require MSL 2.0.");
8367 		break; // Nothing to do in the body
8368 
8369 	default:
8370 		CompilerGLSL::emit_instruction(instruction);
8371 		break;
8372 	}
8373 
8374 	previous_instruction_opcode = opcode;
8375 }
8376 
emit_texture_op(const Instruction & i,bool sparse)8377 void CompilerMSL::emit_texture_op(const Instruction &i, bool sparse)
8378 {
8379 	if (sparse)
8380 		SPIRV_CROSS_THROW("Sparse feedback not yet supported in MSL.");
8381 
8382 	if (msl_options.use_framebuffer_fetch_subpasses)
8383 	{
8384 		auto *ops = stream(i);
8385 
8386 		uint32_t result_type_id = ops[0];
8387 		uint32_t id = ops[1];
8388 		uint32_t img = ops[2];
8389 
8390 		auto &type = expression_type(img);
8391 		auto &imgtype = get<SPIRType>(type.self);
8392 
8393 		// Use Metal's native frame-buffer fetch API for subpass inputs.
8394 		if (imgtype.image.dim == DimSubpassData)
8395 		{
8396 			// Subpass inputs cannot be invalidated,
8397 			// so just forward the expression directly.
8398 			string expr = to_expression(img);
8399 			emit_op(result_type_id, id, expr, true);
8400 			return;
8401 		}
8402 	}
8403 
8404 	// Fallback to default implementation
8405 	CompilerGLSL::emit_texture_op(i, sparse);
8406 }
8407 
emit_barrier(uint32_t id_exe_scope,uint32_t id_mem_scope,uint32_t id_mem_sem)8408 void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uint32_t id_mem_sem)
8409 {
8410 	if (get_execution_model() != ExecutionModelGLCompute && get_execution_model() != ExecutionModelTessellationControl)
8411 		return;
8412 
8413 	uint32_t exe_scope = id_exe_scope ? evaluate_constant_u32(id_exe_scope) : uint32_t(ScopeInvocation);
8414 	uint32_t mem_scope = id_mem_scope ? evaluate_constant_u32(id_mem_scope) : uint32_t(ScopeInvocation);
8415 	// Use the wider of the two scopes (smaller value)
8416 	exe_scope = min(exe_scope, mem_scope);
8417 
8418 	if (msl_options.emulate_subgroups && exe_scope >= ScopeSubgroup && !id_mem_sem)
8419 		// In this case, we assume a "subgroup" size of 1. The barrier, then, is a noop.
8420 		return;
8421 
8422 	string bar_stmt;
8423 	if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2))
8424 		bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
8425 	else
8426 		bar_stmt = "threadgroup_barrier";
8427 	bar_stmt += "(";
8428 
8429 	uint32_t mem_sem = id_mem_sem ? evaluate_constant_u32(id_mem_sem) : uint32_t(MemorySemanticsMaskNone);
8430 
8431 	// Use the | operator to combine flags if we can.
8432 	if (msl_options.supports_msl_version(1, 2))
8433 	{
8434 		string mem_flags = "";
8435 		// For tesc shaders, this also affects objects in the Output storage class.
8436 		// Since in Metal, these are placed in a device buffer, we have to sync device memory here.
8437 		if (get_execution_model() == ExecutionModelTessellationControl ||
8438 		    (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)))
8439 			mem_flags += "mem_flags::mem_device";
8440 
8441 		// Fix tessellation patch function processing
8442 		if (get_execution_model() == ExecutionModelTessellationControl ||
8443 		    (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
8444 		{
8445 			if (!mem_flags.empty())
8446 				mem_flags += " | ";
8447 			mem_flags += "mem_flags::mem_threadgroup";
8448 		}
8449 		if (mem_sem & MemorySemanticsImageMemoryMask)
8450 		{
8451 			if (!mem_flags.empty())
8452 				mem_flags += " | ";
8453 			mem_flags += "mem_flags::mem_texture";
8454 		}
8455 
8456 		if (mem_flags.empty())
8457 			mem_flags = "mem_flags::mem_none";
8458 
8459 		bar_stmt += mem_flags;
8460 	}
8461 	else
8462 	{
8463 		if ((mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)) &&
8464 		    (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
8465 			bar_stmt += "mem_flags::mem_device_and_threadgroup";
8466 		else if (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask))
8467 			bar_stmt += "mem_flags::mem_device";
8468 		else if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask))
8469 			bar_stmt += "mem_flags::mem_threadgroup";
8470 		else if (mem_sem & MemorySemanticsImageMemoryMask)
8471 			bar_stmt += "mem_flags::mem_texture";
8472 		else
8473 			bar_stmt += "mem_flags::mem_none";
8474 	}
8475 
8476 	bar_stmt += ");";
8477 
8478 	statement(bar_stmt);
8479 
8480 	assert(current_emitting_block);
8481 	flush_control_dependent_expressions(current_emitting_block->self);
8482 	flush_all_active_variables();
8483 }
8484 
storage_class_array_is_thread(StorageClass storage)8485 static bool storage_class_array_is_thread(StorageClass storage)
8486 {
8487 	switch (storage)
8488 	{
8489 	case StorageClassInput:
8490 	case StorageClassOutput:
8491 	case StorageClassGeneric:
8492 	case StorageClassFunction:
8493 	case StorageClassPrivate:
8494 		return true;
8495 
8496 	default:
8497 		return false;
8498 	}
8499 }
8500 
emit_array_copy(const string & lhs,uint32_t lhs_id,uint32_t rhs_id,StorageClass lhs_storage,StorageClass rhs_storage)8501 void CompilerMSL::emit_array_copy(const string &lhs, uint32_t lhs_id, uint32_t rhs_id,
8502 								  StorageClass lhs_storage, StorageClass rhs_storage)
8503 {
8504 	// Allow Metal to use the array<T> template to make arrays a value type.
8505 	// This, however, cannot be used for threadgroup address specifiers, so consider the custom array copy as fallback.
8506 	bool lhs_is_thread_storage = storage_class_array_is_thread(lhs_storage);
8507 	bool rhs_is_thread_storage = storage_class_array_is_thread(rhs_storage);
8508 
8509 	bool lhs_is_array_template = lhs_is_thread_storage;
8510 	bool rhs_is_array_template = rhs_is_thread_storage;
8511 
8512 	// Special considerations for stage IO variables.
8513 	// If the variable is actually backed by non-user visible device storage, we use array templates for those.
8514 	//
8515 	// Another special consideration is given to thread local variables which happen to have Offset decorations
8516 	// applied to them. Block-like types do not use array templates, so we need to force POD path if we detect
8517 	// these scenarios. This check isn't perfect since it would be technically possible to mix and match these things,
8518 	// and for a fully correct solution we might have to track array template state through access chains as well,
8519 	// but for all reasonable use cases, this should suffice.
8520 	// This special case should also only apply to Function/Private storage classes.
8521 	// We should not check backing variable for temporaries.
8522 	auto *lhs_var = maybe_get_backing_variable(lhs_id);
8523 	if (lhs_var && lhs_storage == StorageClassStorageBuffer && storage_class_array_is_thread(lhs_var->storage))
8524 		lhs_is_array_template = true;
8525 	else if (lhs_var && (lhs_storage == StorageClassFunction || lhs_storage == StorageClassPrivate) &&
8526 	         type_is_block_like(get<SPIRType>(lhs_var->basetype)))
8527 		lhs_is_array_template = false;
8528 
8529 	auto *rhs_var = maybe_get_backing_variable(rhs_id);
8530 	if (rhs_var && rhs_storage == StorageClassStorageBuffer && storage_class_array_is_thread(rhs_var->storage))
8531 		rhs_is_array_template = true;
8532 	else if (rhs_var && (rhs_storage == StorageClassFunction || rhs_storage == StorageClassPrivate) &&
8533 	         type_is_block_like(get<SPIRType>(rhs_var->basetype)))
8534 		rhs_is_array_template = false;
8535 
8536 	// If threadgroup storage qualifiers are *not* used:
8537 	// Avoid spvCopy* wrapper functions; Otherwise, spvUnsafeArray<> template cannot be used with that storage qualifier.
8538 	if (lhs_is_array_template && rhs_is_array_template && !using_builtin_array())
8539 	{
8540 		statement(lhs, " = ", to_expression(rhs_id), ";");
8541 	}
8542 	else
8543 	{
8544 		// Assignment from an array initializer is fine.
8545 		auto &type = expression_type(rhs_id);
8546 		auto *var = maybe_get_backing_variable(rhs_id);
8547 
8548 		// Unfortunately, we cannot template on address space in MSL,
8549 		// so explicit address space redirection it is ...
8550 		bool is_constant = false;
8551 		if (ir.ids[rhs_id].get_type() == TypeConstant)
8552 		{
8553 			is_constant = true;
8554 		}
8555 		else if (var && var->remapped_variable && var->statically_assigned &&
8556 		         ir.ids[var->static_expression].get_type() == TypeConstant)
8557 		{
8558 			is_constant = true;
8559 		}
8560 		else if (rhs_storage == StorageClassUniform)
8561 		{
8562 			is_constant = true;
8563 		}
8564 
8565 		// For the case where we have OpLoad triggering an array copy,
8566 		// we cannot easily detect this case ahead of time since it's
8567 		// context dependent. We might have to force a recompile here
8568 		// if this is the only use of array copies in our shader.
8569 		if (type.array.size() > 1)
8570 		{
8571 			if (type.array.size() > kArrayCopyMultidimMax)
8572 				SPIRV_CROSS_THROW("Cannot support this many dimensions for arrays of arrays.");
8573 			auto func = static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + type.array.size());
8574 			add_spv_func_and_recompile(func);
8575 		}
8576 		else
8577 			add_spv_func_and_recompile(SPVFuncImplArrayCopy);
8578 
8579 		const char *tag = nullptr;
8580 		if (lhs_is_thread_storage && is_constant)
8581 			tag = "FromConstantToStack";
8582 		else if (lhs_storage == StorageClassWorkgroup && is_constant)
8583 			tag = "FromConstantToThreadGroup";
8584 		else if (lhs_is_thread_storage && rhs_is_thread_storage)
8585 			tag = "FromStackToStack";
8586 		else if (lhs_storage == StorageClassWorkgroup && rhs_is_thread_storage)
8587 			tag = "FromStackToThreadGroup";
8588 		else if (lhs_is_thread_storage && rhs_storage == StorageClassWorkgroup)
8589 			tag = "FromThreadGroupToStack";
8590 		else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassWorkgroup)
8591 			tag = "FromThreadGroupToThreadGroup";
8592 		else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassStorageBuffer)
8593 			tag = "FromDeviceToDevice";
8594 		else if (lhs_storage == StorageClassStorageBuffer && is_constant)
8595 			tag = "FromConstantToDevice";
8596 		else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassWorkgroup)
8597 			tag = "FromThreadGroupToDevice";
8598 		else if (lhs_storage == StorageClassStorageBuffer && rhs_is_thread_storage)
8599 			tag = "FromStackToDevice";
8600 		else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassStorageBuffer)
8601 			tag = "FromDeviceToThreadGroup";
8602 		else if (lhs_is_thread_storage && rhs_storage == StorageClassStorageBuffer)
8603 			tag = "FromDeviceToStack";
8604 		else
8605 			SPIRV_CROSS_THROW("Unknown storage class used for copying arrays.");
8606 
8607 		// Pass internal array of spvUnsafeArray<> into wrapper functions
8608 		if (lhs_is_array_template && rhs_is_array_template && !msl_options.force_native_arrays)
8609 			statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ".elements, ", to_expression(rhs_id), ".elements);");
8610 		if (lhs_is_array_template && !msl_options.force_native_arrays)
8611 			statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ".elements, ", to_expression(rhs_id), ");");
8612 		else if (rhs_is_array_template && !msl_options.force_native_arrays)
8613 			statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ", ", to_expression(rhs_id), ".elements);");
8614 		else
8615 			statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ", ", to_expression(rhs_id), ");");
8616 	}
8617 }
8618 
get_physical_tess_level_array_size(spv::BuiltIn builtin) const8619 uint32_t CompilerMSL::get_physical_tess_level_array_size(spv::BuiltIn builtin) const
8620 {
8621 	if (get_execution_mode_bitset().get(ExecutionModeTriangles))
8622 		return builtin == BuiltInTessLevelInner ? 1 : 3;
8623 	else
8624 		return builtin == BuiltInTessLevelInner ? 2 : 4;
8625 }
8626 
8627 // Since MSL does not allow arrays to be copied via simple variable assignment,
8628 // if the LHS and RHS represent an assignment of an entire array, it must be
8629 // implemented by calling an array copy function.
8630 // Returns whether the struct assignment was emitted.
maybe_emit_array_assignment(uint32_t id_lhs,uint32_t id_rhs)8631 bool CompilerMSL::maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs)
8632 {
8633 	// We only care about assignments of an entire array
8634 	auto &type = expression_type(id_rhs);
8635 	if (type.array.size() == 0)
8636 		return false;
8637 
8638 	auto *var = maybe_get<SPIRVariable>(id_lhs);
8639 
8640 	// Is this a remapped, static constant? Don't do anything.
8641 	if (var && var->remapped_variable && var->statically_assigned)
8642 		return true;
8643 
8644 	if (ir.ids[id_rhs].get_type() == TypeConstant && var && var->deferred_declaration)
8645 	{
8646 		// Special case, if we end up declaring a variable when assigning the constant array,
8647 		// we can avoid the copy by directly assigning the constant expression.
8648 		// This is likely necessary to be able to use a variable as a true look-up table, as it is unlikely
8649 		// the compiler will be able to optimize the spvArrayCopy() into a constant LUT.
8650 		// After a variable has been declared, we can no longer assign constant arrays in MSL unfortunately.
8651 		statement(to_expression(id_lhs), " = ", constant_expression(get<SPIRConstant>(id_rhs)), ";");
8652 		return true;
8653 	}
8654 
8655 	if (get_execution_model() == ExecutionModelTessellationControl &&
8656 	    has_decoration(id_lhs, DecorationBuiltIn))
8657 	{
8658 		auto builtin = BuiltIn(get_decoration(id_lhs, DecorationBuiltIn));
8659 		// Need to manually unroll the array store.
8660 		if (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter)
8661 		{
8662 			uint32_t array_size = get_physical_tess_level_array_size(builtin);
8663 			if (array_size == 1)
8664 				statement(to_expression(id_lhs), " = half(", to_expression(id_rhs), "[0]);");
8665 			else
8666 			{
8667 				for (uint32_t i = 0; i < array_size; i++)
8668 					statement(to_expression(id_lhs), "[", i, "] = half(", to_expression(id_rhs), "[", i, "]);");
8669 			}
8670 			return true;
8671 		}
8672 	}
8673 
8674 	// Ensure the LHS variable has been declared
8675 	auto *p_v_lhs = maybe_get_backing_variable(id_lhs);
8676 	if (p_v_lhs)
8677 		flush_variable_declaration(p_v_lhs->self);
8678 
8679 	auto lhs_storage = get_expression_effective_storage_class(id_lhs);
8680 	auto rhs_storage = get_expression_effective_storage_class(id_rhs);
8681 	emit_array_copy(to_expression(id_lhs), id_lhs, id_rhs, lhs_storage, rhs_storage);
8682 	register_write(id_lhs);
8683 
8684 	return true;
8685 }
8686 
8687 // Emits one of the atomic functions. In MSL, the atomic functions operate on pointers
emit_atomic_func_op(uint32_t result_type,uint32_t result_id,const char * op,uint32_t mem_order_1,uint32_t mem_order_2,bool has_mem_order_2,uint32_t obj,uint32_t op1,bool op1_is_pointer,bool op1_is_literal,uint32_t op2)8688 void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, uint32_t mem_order_1,
8689                                       uint32_t mem_order_2, bool has_mem_order_2, uint32_t obj, uint32_t op1,
8690                                       bool op1_is_pointer, bool op1_is_literal, uint32_t op2)
8691 {
8692 	string exp = string(op) + "(";
8693 
8694 	auto &type = get_pointee_type(expression_type(obj));
8695 	exp += "(";
8696 	auto *var = maybe_get_backing_variable(obj);
8697 	if (!var)
8698 		SPIRV_CROSS_THROW("No backing variable for atomic operation.");
8699 
8700 	// Emulate texture2D atomic operations
8701 	const auto &res_type = get<SPIRType>(var->basetype);
8702 	if (res_type.storage == StorageClassUniformConstant && res_type.basetype == SPIRType::Image)
8703 	{
8704 		exp += "device";
8705 	}
8706 	else
8707 	{
8708 		exp += get_argument_address_space(*var);
8709 	}
8710 
8711 	exp += " atomic_";
8712 	exp += type_to_glsl(type);
8713 	exp += "*)";
8714 
8715 	exp += "&";
8716 	exp += to_enclosed_expression(obj);
8717 
8718 	bool is_atomic_compare_exchange_strong = op1_is_pointer && op1;
8719 
8720 	if (is_atomic_compare_exchange_strong)
8721 	{
8722 		assert(strcmp(op, "atomic_compare_exchange_weak_explicit") == 0);
8723 		assert(op2);
8724 		assert(has_mem_order_2);
8725 		exp += ", &";
8726 		exp += to_name(result_id);
8727 		exp += ", ";
8728 		exp += to_expression(op2);
8729 		exp += ", ";
8730 		exp += get_memory_order(mem_order_1);
8731 		exp += ", ";
8732 		exp += get_memory_order(mem_order_2);
8733 		exp += ")";
8734 
8735 		// MSL only supports the weak atomic compare exchange, so emit a CAS loop here.
8736 		// The MSL function returns false if the atomic write fails OR the comparison test fails,
8737 		// so we must validate that it wasn't the comparison test that failed before continuing
8738 		// the CAS loop, otherwise it will loop infinitely, with the comparison test always failing.
8739 		// The function updates the comparitor value from the memory value, so the additional
8740 		// comparison test evaluates the memory value against the expected value.
8741 		emit_uninitialized_temporary_expression(result_type, result_id);
8742 		statement("do");
8743 		begin_scope();
8744 		statement(to_name(result_id), " = ", to_expression(op1), ";");
8745 		end_scope_decl(join("while (!", exp, " && ", to_name(result_id), " == ", to_enclosed_expression(op1), ")"));
8746 	}
8747 	else
8748 	{
8749 		assert(strcmp(op, "atomic_compare_exchange_weak_explicit") != 0);
8750 		if (op1)
8751 		{
8752 			if (op1_is_literal)
8753 				exp += join(", ", op1);
8754 			else
8755 				exp += ", " + to_expression(op1);
8756 		}
8757 		if (op2)
8758 			exp += ", " + to_expression(op2);
8759 
8760 		exp += string(", ") + get_memory_order(mem_order_1);
8761 		if (has_mem_order_2)
8762 			exp += string(", ") + get_memory_order(mem_order_2);
8763 
8764 		exp += ")";
8765 
8766 		if (strcmp(op, "atomic_store_explicit") != 0)
8767 			emit_op(result_type, result_id, exp, false);
8768 		else
8769 			statement(exp, ";");
8770 	}
8771 
8772 	flush_all_atomic_capable_variables();
8773 }
8774 
8775 // Metal only supports relaxed memory order for now
get_memory_order(uint32_t)8776 const char *CompilerMSL::get_memory_order(uint32_t)
8777 {
8778 	return "memory_order_relaxed";
8779 }
8780 
8781 // Override for MSL-specific extension syntax instructions
emit_glsl_op(uint32_t result_type,uint32_t id,uint32_t eop,const uint32_t * args,uint32_t count)8782 void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
8783 {
8784 	auto op = static_cast<GLSLstd450>(eop);
8785 
8786 	// If we need to do implicit bitcasts, make sure we do it with the correct type.
8787 	uint32_t integer_width = get_integer_width_for_glsl_instruction(op, args, count);
8788 	auto int_type = to_signed_basetype(integer_width);
8789 	auto uint_type = to_unsigned_basetype(integer_width);
8790 
8791 	switch (op)
8792 	{
8793 	case GLSLstd450Atan2:
8794 		emit_binary_func_op(result_type, id, args[0], args[1], "atan2");
8795 		break;
8796 	case GLSLstd450InverseSqrt:
8797 		emit_unary_func_op(result_type, id, args[0], "rsqrt");
8798 		break;
8799 	case GLSLstd450RoundEven:
8800 		emit_unary_func_op(result_type, id, args[0], "rint");
8801 		break;
8802 
8803 	case GLSLstd450FindILsb:
8804 	{
8805 		// In this template version of findLSB, we return T.
8806 		auto basetype = expression_type(args[0]).basetype;
8807 		emit_unary_func_op_cast(result_type, id, args[0], "spvFindLSB", basetype, basetype);
8808 		break;
8809 	}
8810 
8811 	case GLSLstd450FindSMsb:
8812 		emit_unary_func_op_cast(result_type, id, args[0], "spvFindSMSB", int_type, int_type);
8813 		break;
8814 
8815 	case GLSLstd450FindUMsb:
8816 		emit_unary_func_op_cast(result_type, id, args[0], "spvFindUMSB", uint_type, uint_type);
8817 		break;
8818 
8819 	case GLSLstd450PackSnorm4x8:
8820 		emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm4x8");
8821 		break;
8822 	case GLSLstd450PackUnorm4x8:
8823 		emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm4x8");
8824 		break;
8825 	case GLSLstd450PackSnorm2x16:
8826 		emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm2x16");
8827 		break;
8828 	case GLSLstd450PackUnorm2x16:
8829 		emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm2x16");
8830 		break;
8831 
8832 	case GLSLstd450PackHalf2x16:
8833 	{
8834 		auto expr = join("as_type<uint>(half2(", to_expression(args[0]), "))");
8835 		emit_op(result_type, id, expr, should_forward(args[0]));
8836 		inherit_expression_dependencies(id, args[0]);
8837 		break;
8838 	}
8839 
8840 	case GLSLstd450UnpackSnorm4x8:
8841 		emit_unary_func_op(result_type, id, args[0], "unpack_snorm4x8_to_float");
8842 		break;
8843 	case GLSLstd450UnpackUnorm4x8:
8844 		emit_unary_func_op(result_type, id, args[0], "unpack_unorm4x8_to_float");
8845 		break;
8846 	case GLSLstd450UnpackSnorm2x16:
8847 		emit_unary_func_op(result_type, id, args[0], "unpack_snorm2x16_to_float");
8848 		break;
8849 	case GLSLstd450UnpackUnorm2x16:
8850 		emit_unary_func_op(result_type, id, args[0], "unpack_unorm2x16_to_float");
8851 		break;
8852 
8853 	case GLSLstd450UnpackHalf2x16:
8854 	{
8855 		auto expr = join("float2(as_type<half2>(", to_expression(args[0]), "))");
8856 		emit_op(result_type, id, expr, should_forward(args[0]));
8857 		inherit_expression_dependencies(id, args[0]);
8858 		break;
8859 	}
8860 
8861 	case GLSLstd450PackDouble2x32:
8862 		emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450PackDouble2x32"); // Currently unsupported
8863 		break;
8864 	case GLSLstd450UnpackDouble2x32:
8865 		emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450UnpackDouble2x32"); // Currently unsupported
8866 		break;
8867 
8868 	case GLSLstd450MatrixInverse:
8869 	{
8870 		auto &mat_type = get<SPIRType>(result_type);
8871 		switch (mat_type.columns)
8872 		{
8873 		case 2:
8874 			emit_unary_func_op(result_type, id, args[0], "spvInverse2x2");
8875 			break;
8876 		case 3:
8877 			emit_unary_func_op(result_type, id, args[0], "spvInverse3x3");
8878 			break;
8879 		case 4:
8880 			emit_unary_func_op(result_type, id, args[0], "spvInverse4x4");
8881 			break;
8882 		default:
8883 			break;
8884 		}
8885 		break;
8886 	}
8887 
8888 	case GLSLstd450FMin:
8889 		// If the result type isn't float, don't bother calling the specific
8890 		// precise::/fast:: version. Metal doesn't have those for half and
8891 		// double types.
8892 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8893 			emit_binary_func_op(result_type, id, args[0], args[1], "min");
8894 		else
8895 			emit_binary_func_op(result_type, id, args[0], args[1], "fast::min");
8896 		break;
8897 
8898 	case GLSLstd450FMax:
8899 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8900 			emit_binary_func_op(result_type, id, args[0], args[1], "max");
8901 		else
8902 			emit_binary_func_op(result_type, id, args[0], args[1], "fast::max");
8903 		break;
8904 
8905 	case GLSLstd450FClamp:
8906 		// TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
8907 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8908 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
8909 		else
8910 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "fast::clamp");
8911 		break;
8912 
8913 	case GLSLstd450NMin:
8914 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8915 			emit_binary_func_op(result_type, id, args[0], args[1], "min");
8916 		else
8917 			emit_binary_func_op(result_type, id, args[0], args[1], "precise::min");
8918 		break;
8919 
8920 	case GLSLstd450NMax:
8921 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8922 			emit_binary_func_op(result_type, id, args[0], args[1], "max");
8923 		else
8924 			emit_binary_func_op(result_type, id, args[0], args[1], "precise::max");
8925 		break;
8926 
8927 	case GLSLstd450NClamp:
8928 		// TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
8929 		if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8930 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
8931 		else
8932 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "precise::clamp");
8933 		break;
8934 
8935 	case GLSLstd450InterpolateAtCentroid:
8936 	{
8937 		// We can't just emit the expression normally, because the qualified name contains a call to the default
8938 		// interpolate method, or refers to a local variable. We saved the interface index we need; use it to construct
8939 		// the base for the method call.
8940 		uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
8941 		string component;
8942 		if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
8943 		{
8944 			uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
8945 			auto *c = maybe_get<SPIRConstant>(index_expr);
8946 			if (!c || c->specialization)
8947 				component = join("[", to_expression(index_expr), "]");
8948 			else
8949 				component = join(".", index_to_swizzle(c->scalar()));
8950 		}
8951 		emit_op(result_type, id,
8952 		        join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
8953 		             ".interpolate_at_centroid()", component),
8954 		        should_forward(args[0]));
8955 		break;
8956 	}
8957 
8958 	case GLSLstd450InterpolateAtSample:
8959 	{
8960 		uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
8961 		string component;
8962 		if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
8963 		{
8964 			uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
8965 			auto *c = maybe_get<SPIRConstant>(index_expr);
8966 			if (!c || c->specialization)
8967 				component = join("[", to_expression(index_expr), "]");
8968 			else
8969 				component = join(".", index_to_swizzle(c->scalar()));
8970 		}
8971 		emit_op(result_type, id,
8972 		        join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
8973 		             ".interpolate_at_sample(", to_expression(args[1]), ")", component),
8974 		        should_forward(args[0]) && should_forward(args[1]));
8975 		break;
8976 	}
8977 
8978 	case GLSLstd450InterpolateAtOffset:
8979 	{
8980 		uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
8981 		string component;
8982 		if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
8983 		{
8984 			uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
8985 			auto *c = maybe_get<SPIRConstant>(index_expr);
8986 			if (!c || c->specialization)
8987 				component = join("[", to_expression(index_expr), "]");
8988 			else
8989 				component = join(".", index_to_swizzle(c->scalar()));
8990 		}
8991 		// Like Direct3D, Metal puts the (0, 0) at the upper-left corner, not the center as SPIR-V and GLSL do.
8992 		// Offset the offset by (1/2 - 1/16), or 0.4375, to compensate for this.
8993 		// It has to be (1/2 - 1/16) and not 1/2, or several CTS tests subtly break on Intel.
8994 		emit_op(result_type, id,
8995 		        join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
8996 		             ".interpolate_at_offset(", to_expression(args[1]), " + 0.4375)", component),
8997 		        should_forward(args[0]) && should_forward(args[1]));
8998 		break;
8999 	}
9000 
9001 	case GLSLstd450Distance:
9002 		// MSL does not support scalar versions here.
9003 		if (expression_type(args[0]).vecsize == 1)
9004 		{
9005 			// Equivalent to length(a - b) -> abs(a - b).
9006 			emit_op(result_type, id,
9007 			        join("abs(", to_enclosed_unpacked_expression(args[0]), " - ",
9008 			             to_enclosed_unpacked_expression(args[1]), ")"),
9009 			        should_forward(args[0]) && should_forward(args[1]));
9010 			inherit_expression_dependencies(id, args[0]);
9011 			inherit_expression_dependencies(id, args[1]);
9012 		}
9013 		else
9014 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9015 		break;
9016 
9017 	case GLSLstd450Length:
9018 		// MSL does not support scalar versions here.
9019 		if (expression_type(args[0]).vecsize == 1)
9020 		{
9021 			// Equivalent to abs().
9022 			emit_unary_func_op(result_type, id, args[0], "abs");
9023 		}
9024 		else
9025 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9026 		break;
9027 
9028 	case GLSLstd450Normalize:
9029 		// MSL does not support scalar versions here.
9030 		if (expression_type(args[0]).vecsize == 1)
9031 		{
9032 			// Returns -1 or 1 for valid input, sign() does the job.
9033 			emit_unary_func_op(result_type, id, args[0], "sign");
9034 		}
9035 		else
9036 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9037 		break;
9038 
9039 	case GLSLstd450Reflect:
9040 		if (get<SPIRType>(result_type).vecsize == 1)
9041 			emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect");
9042 		else
9043 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9044 		break;
9045 
9046 	case GLSLstd450Refract:
9047 		if (get<SPIRType>(result_type).vecsize == 1)
9048 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract");
9049 		else
9050 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9051 		break;
9052 
9053 	case GLSLstd450FaceForward:
9054 		if (get<SPIRType>(result_type).vecsize == 1)
9055 			emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvFaceForward");
9056 		else
9057 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9058 		break;
9059 
9060 	case GLSLstd450Modf:
9061 	case GLSLstd450Frexp:
9062 	{
9063 		// Special case. If the variable is a scalar access chain, we cannot use it directly. We have to emit a temporary.
9064 		auto *ptr = maybe_get<SPIRExpression>(args[1]);
9065 		if (ptr && ptr->access_chain && is_scalar(expression_type(args[1])))
9066 		{
9067 			register_call_out_argument(args[1]);
9068 			forced_temporaries.insert(id);
9069 
9070 			// Need to create temporaries and copy over to access chain after.
9071 			// We cannot directly take the reference of a vector swizzle in MSL, even if it's scalar ...
9072 			uint32_t &tmp_id = extra_sub_expressions[id];
9073 			if (!tmp_id)
9074 				tmp_id = ir.increase_bound_by(1);
9075 
9076 			uint32_t tmp_type_id = get_pointee_type_id(ptr->expression_type);
9077 			emit_uninitialized_temporary_expression(tmp_type_id, tmp_id);
9078 			emit_binary_func_op(result_type, id, args[0], tmp_id, eop == GLSLstd450Modf ? "modf" : "frexp");
9079 			statement(to_expression(args[1]), " = ", to_expression(tmp_id), ";");
9080 		}
9081 		else
9082 			CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9083 		break;
9084 	}
9085 
9086 	default:
9087 		CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9088 		break;
9089 	}
9090 }
9091 
emit_spv_amd_shader_trinary_minmax_op(uint32_t result_type,uint32_t id,uint32_t eop,const uint32_t * args,uint32_t count)9092 void CompilerMSL::emit_spv_amd_shader_trinary_minmax_op(uint32_t result_type, uint32_t id, uint32_t eop,
9093                                                         const uint32_t *args, uint32_t count)
9094 {
9095 	enum AMDShaderTrinaryMinMax
9096 	{
9097 		FMin3AMD = 1,
9098 		UMin3AMD = 2,
9099 		SMin3AMD = 3,
9100 		FMax3AMD = 4,
9101 		UMax3AMD = 5,
9102 		SMax3AMD = 6,
9103 		FMid3AMD = 7,
9104 		UMid3AMD = 8,
9105 		SMid3AMD = 9
9106 	};
9107 
9108 	if (!msl_options.supports_msl_version(2, 1))
9109 		SPIRV_CROSS_THROW("Trinary min/max functions require MSL 2.1.");
9110 
9111 	auto op = static_cast<AMDShaderTrinaryMinMax>(eop);
9112 
9113 	switch (op)
9114 	{
9115 	case FMid3AMD:
9116 	case UMid3AMD:
9117 	case SMid3AMD:
9118 		emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "median3");
9119 		break;
9120 	default:
9121 		CompilerGLSL::emit_spv_amd_shader_trinary_minmax_op(result_type, id, eop, args, count);
9122 		break;
9123 	}
9124 }
9125 
9126 // Emit a structure declaration for the specified interface variable.
emit_interface_block(uint32_t ib_var_id)9127 void CompilerMSL::emit_interface_block(uint32_t ib_var_id)
9128 {
9129 	if (ib_var_id)
9130 	{
9131 		auto &ib_var = get<SPIRVariable>(ib_var_id);
9132 		auto &ib_type = get_variable_data_type(ib_var);
9133 		//assert(ib_type.basetype == SPIRType::Struct && !ib_type.member_types.empty());
9134 		assert(ib_type.basetype == SPIRType::Struct);
9135 		emit_struct(ib_type);
9136 	}
9137 }
9138 
9139 // Emits the declaration signature of the specified function.
9140 // If this is the entry point function, Metal-specific return value and function arguments are added.
emit_function_prototype(SPIRFunction & func,const Bitset &)9141 void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &)
9142 {
9143 	if (func.self != ir.default_entry_point)
9144 		add_function_overload(func);
9145 
9146 	local_variable_names = resource_names;
9147 	string decl;
9148 
9149 	processing_entry_point = func.self == ir.default_entry_point;
9150 
9151 	// Metal helper functions must be static force-inline otherwise they will cause problems when linked together in a single Metallib.
9152 	if (!processing_entry_point)
9153 		statement(force_inline);
9154 
9155 	auto &type = get<SPIRType>(func.return_type);
9156 
9157 	if (!type.array.empty() && msl_options.force_native_arrays)
9158 	{
9159 		// We cannot return native arrays in MSL, so "return" through an out variable.
9160 		decl += "void";
9161 	}
9162 	else
9163 	{
9164 		decl += func_type_decl(type);
9165 	}
9166 
9167 	decl += " ";
9168 	decl += to_name(func.self);
9169 	decl += "(";
9170 
9171 	if (!type.array.empty() && msl_options.force_native_arrays)
9172 	{
9173 		// Fake arrays returns by writing to an out array instead.
9174 		decl += "thread ";
9175 		decl += type_to_glsl(type);
9176 		decl += " (&spvReturnValue)";
9177 		decl += type_to_array_glsl(type);
9178 		if (!func.arguments.empty())
9179 			decl += ", ";
9180 	}
9181 
9182 	if (processing_entry_point)
9183 	{
9184 		if (msl_options.argument_buffers)
9185 			decl += entry_point_args_argument_buffer(!func.arguments.empty());
9186 		else
9187 			decl += entry_point_args_classic(!func.arguments.empty());
9188 
9189 		// If entry point function has variables that require early declaration,
9190 		// ensure they each have an empty initializer, creating one if needed.
9191 		// This is done at this late stage because the initialization expression
9192 		// is cleared after each compilation pass.
9193 		for (auto var_id : vars_needing_early_declaration)
9194 		{
9195 			auto &ed_var = get<SPIRVariable>(var_id);
9196 			ID &initializer = ed_var.initializer;
9197 			if (!initializer)
9198 				initializer = ir.increase_bound_by(1);
9199 
9200 			// Do not override proper initializers.
9201 			if (ir.ids[initializer].get_type() == TypeNone || ir.ids[initializer].get_type() == TypeExpression)
9202 				set<SPIRExpression>(ed_var.initializer, "{}", ed_var.basetype, true);
9203 		}
9204 	}
9205 
9206 	for (auto &arg : func.arguments)
9207 	{
9208 		uint32_t name_id = arg.id;
9209 
9210 		auto *var = maybe_get<SPIRVariable>(arg.id);
9211 		if (var)
9212 		{
9213 			// If we need to modify the name of the variable, make sure we modify the original variable.
9214 			// Our alias is just a shadow variable.
9215 			if (arg.alias_global_variable && var->basevariable)
9216 				name_id = var->basevariable;
9217 
9218 			var->parameter = &arg; // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
9219 		}
9220 
9221 		add_local_variable_name(name_id);
9222 
9223 		decl += argument_decl(arg);
9224 
9225 		bool is_dynamic_img_sampler = has_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
9226 
9227 		auto &arg_type = get<SPIRType>(arg.type);
9228 		if (arg_type.basetype == SPIRType::SampledImage && !is_dynamic_img_sampler)
9229 		{
9230 			// Manufacture automatic plane args for multiplanar texture
9231 			uint32_t planes = 1;
9232 			if (auto *constexpr_sampler = find_constexpr_sampler(name_id))
9233 				if (constexpr_sampler->ycbcr_conversion_enable)
9234 					planes = constexpr_sampler->planes;
9235 			for (uint32_t i = 1; i < planes; i++)
9236 				decl += join(", ", argument_decl(arg), plane_name_suffix, i);
9237 
9238 			// Manufacture automatic sampler arg for SampledImage texture
9239 			if (arg_type.image.dim != DimBuffer)
9240 				decl += join(", thread const ", sampler_type(arg_type, arg.id), " ", to_sampler_expression(arg.id));
9241 		}
9242 
9243 		// Manufacture automatic swizzle arg.
9244 		if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(arg_type) &&
9245 		    !is_dynamic_img_sampler)
9246 		{
9247 			bool arg_is_array = !arg_type.array.empty();
9248 			decl += join(", constant uint", arg_is_array ? "* " : "& ", to_swizzle_expression(arg.id));
9249 		}
9250 
9251 		if (buffers_requiring_array_length.count(name_id))
9252 		{
9253 			bool arg_is_array = !arg_type.array.empty();
9254 			decl += join(", constant uint", arg_is_array ? "* " : "& ", to_buffer_size_expression(name_id));
9255 		}
9256 
9257 		if (&arg != &func.arguments.back())
9258 			decl += ", ";
9259 	}
9260 
9261 	decl += ")";
9262 	statement(decl);
9263 }
9264 
needs_chroma_reconstruction(const MSLConstexprSampler * constexpr_sampler)9265 static bool needs_chroma_reconstruction(const MSLConstexprSampler *constexpr_sampler)
9266 {
9267 	// For now, only multiplanar images need explicit reconstruction. GBGR and BGRG images
9268 	// use implicit reconstruction.
9269 	return constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && constexpr_sampler->planes > 1;
9270 }
9271 
9272 // Returns the texture sampling function string for the specified image and sampling characteristics.
to_function_name(const TextureFunctionNameArguments & args)9273 string CompilerMSL::to_function_name(const TextureFunctionNameArguments &args)
9274 {
9275 	VariableID img = args.base.img;
9276 	auto &imgtype = *args.base.imgtype;
9277 
9278 	const MSLConstexprSampler *constexpr_sampler = nullptr;
9279 	bool is_dynamic_img_sampler = false;
9280 	if (auto *var = maybe_get_backing_variable(img))
9281 	{
9282 		constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
9283 		is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
9284 	}
9285 
9286 	// Special-case gather. We have to alter the component being looked up
9287 	// in the swizzle case.
9288 	if (msl_options.swizzle_texture_samples && args.base.is_gather && !is_dynamic_img_sampler &&
9289 	    (!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable))
9290 	{
9291 		add_spv_func_and_recompile(imgtype.image.depth ? SPVFuncImplGatherCompareSwizzle : SPVFuncImplGatherSwizzle);
9292 		return imgtype.image.depth ? "spvGatherCompareSwizzle" : "spvGatherSwizzle";
9293 	}
9294 
9295 	auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
9296 
9297 	// Texture reference
9298 	string fname;
9299 	if (needs_chroma_reconstruction(constexpr_sampler) && !is_dynamic_img_sampler)
9300 	{
9301 		if (constexpr_sampler->planes != 2 && constexpr_sampler->planes != 3)
9302 			SPIRV_CROSS_THROW("Unhandled number of color image planes!");
9303 		// 444 images aren't downsampled, so we don't need to do linear filtering.
9304 		if (constexpr_sampler->resolution == MSL_FORMAT_RESOLUTION_444 ||
9305 		    constexpr_sampler->chroma_filter == MSL_SAMPLER_FILTER_NEAREST)
9306 		{
9307 			if (constexpr_sampler->planes == 2)
9308 				add_spv_func_and_recompile(SPVFuncImplChromaReconstructNearest2Plane);
9309 			else
9310 				add_spv_func_and_recompile(SPVFuncImplChromaReconstructNearest3Plane);
9311 			fname = "spvChromaReconstructNearest";
9312 		}
9313 		else // Linear with a downsampled format
9314 		{
9315 			fname = "spvChromaReconstructLinear";
9316 			switch (constexpr_sampler->resolution)
9317 			{
9318 			case MSL_FORMAT_RESOLUTION_444:
9319 				assert(false);
9320 				break; // not reached
9321 			case MSL_FORMAT_RESOLUTION_422:
9322 				switch (constexpr_sampler->x_chroma_offset)
9323 				{
9324 				case MSL_CHROMA_LOCATION_COSITED_EVEN:
9325 					if (constexpr_sampler->planes == 2)
9326 						add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422CositedEven2Plane);
9327 					else
9328 						add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422CositedEven3Plane);
9329 					fname += "422CositedEven";
9330 					break;
9331 				case MSL_CHROMA_LOCATION_MIDPOINT:
9332 					if (constexpr_sampler->planes == 2)
9333 						add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422Midpoint2Plane);
9334 					else
9335 						add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422Midpoint3Plane);
9336 					fname += "422Midpoint";
9337 					break;
9338 				default:
9339 					SPIRV_CROSS_THROW("Invalid chroma location.");
9340 				}
9341 				break;
9342 			case MSL_FORMAT_RESOLUTION_420:
9343 				fname += "420";
9344 				switch (constexpr_sampler->x_chroma_offset)
9345 				{
9346 				case MSL_CHROMA_LOCATION_COSITED_EVEN:
9347 					switch (constexpr_sampler->y_chroma_offset)
9348 					{
9349 					case MSL_CHROMA_LOCATION_COSITED_EVEN:
9350 						if (constexpr_sampler->planes == 2)
9351 							add_spv_func_and_recompile(
9352 							    SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane);
9353 						else
9354 							add_spv_func_and_recompile(
9355 							    SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane);
9356 						fname += "XCositedEvenYCositedEven";
9357 						break;
9358 					case MSL_CHROMA_LOCATION_MIDPOINT:
9359 						if (constexpr_sampler->planes == 2)
9360 							add_spv_func_and_recompile(
9361 							    SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane);
9362 						else
9363 							add_spv_func_and_recompile(
9364 							    SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane);
9365 						fname += "XCositedEvenYMidpoint";
9366 						break;
9367 					default:
9368 						SPIRV_CROSS_THROW("Invalid Y chroma location.");
9369 					}
9370 					break;
9371 				case MSL_CHROMA_LOCATION_MIDPOINT:
9372 					switch (constexpr_sampler->y_chroma_offset)
9373 					{
9374 					case MSL_CHROMA_LOCATION_COSITED_EVEN:
9375 						if (constexpr_sampler->planes == 2)
9376 							add_spv_func_and_recompile(
9377 							    SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane);
9378 						else
9379 							add_spv_func_and_recompile(
9380 							    SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane);
9381 						fname += "XMidpointYCositedEven";
9382 						break;
9383 					case MSL_CHROMA_LOCATION_MIDPOINT:
9384 						if (constexpr_sampler->planes == 2)
9385 							add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane);
9386 						else
9387 							add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane);
9388 						fname += "XMidpointYMidpoint";
9389 						break;
9390 					default:
9391 						SPIRV_CROSS_THROW("Invalid Y chroma location.");
9392 					}
9393 					break;
9394 				default:
9395 					SPIRV_CROSS_THROW("Invalid X chroma location.");
9396 				}
9397 				break;
9398 			default:
9399 				SPIRV_CROSS_THROW("Invalid format resolution.");
9400 			}
9401 		}
9402 	}
9403 	else
9404 	{
9405 		fname = to_expression(combined ? combined->image : img) + ".";
9406 
9407 		// Texture function and sampler
9408 		if (args.base.is_fetch)
9409 			fname += "read";
9410 		else if (args.base.is_gather)
9411 			fname += "gather";
9412 		else
9413 			fname += "sample";
9414 
9415 		if (args.has_dref)
9416 			fname += "_compare";
9417 	}
9418 
9419 	return fname;
9420 }
9421 
// Wraps expr in a constructor call for a single-column Float type with the given number of components.
// Returns the resulting expression string, "<constructor>(<expr>)".
string CompilerMSL::convert_to_f32(const string &expr, uint32_t components)
{
	SPIRType float_type;
	float_type.basetype = SPIRType::Float;
	float_type.columns = 1;
	float_type.vecsize = components;

	const string constructor = type_to_glsl_constructor(float_type);
	return join(constructor, "(", expr, ")");
}
9430 
sampling_type_needs_f32_conversion(const SPIRType & type)9431 static inline bool sampling_type_needs_f32_conversion(const SPIRType &type)
9432 {
9433 	// Double is not supported to begin with, but doesn't hurt to check for completion.
9434 	return type.basetype == SPIRType::Half || type.basetype == SPIRType::Double;
9435 }
9436 
9437 // Returns the function args for a texture sampling function for the specified image and sampling characteristics.
to_function_args(const TextureFunctionArguments & args,bool * p_forward)9438 string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool *p_forward)
9439 {
9440 	VariableID img = args.base.img;
9441 	auto &imgtype = *args.base.imgtype;
9442 	uint32_t lod = args.lod;
9443 	uint32_t grad_x = args.grad_x;
9444 	uint32_t grad_y = args.grad_y;
9445 	uint32_t bias = args.bias;
9446 
9447 	const MSLConstexprSampler *constexpr_sampler = nullptr;
9448 	bool is_dynamic_img_sampler = false;
9449 	if (auto *var = maybe_get_backing_variable(img))
9450 	{
9451 		constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
9452 		is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
9453 	}
9454 
9455 	string farg_str;
9456 	bool forward = true;
9457 
9458 	if (!is_dynamic_img_sampler)
9459 	{
9460 		// Texture reference (for some cases)
9461 		if (needs_chroma_reconstruction(constexpr_sampler))
9462 		{
9463 			// Multiplanar images need two or three textures.
9464 			farg_str += to_expression(img);
9465 			for (uint32_t i = 1; i < constexpr_sampler->planes; i++)
9466 				farg_str += join(", ", to_expression(img), plane_name_suffix, i);
9467 		}
9468 		else if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
9469 		         msl_options.swizzle_texture_samples && args.base.is_gather)
9470 		{
9471 			auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
9472 			farg_str += to_expression(combined ? combined->image : img);
9473 		}
9474 
9475 		// Sampler reference
9476 		if (!args.base.is_fetch)
9477 		{
9478 			if (!farg_str.empty())
9479 				farg_str += ", ";
9480 			farg_str += to_sampler_expression(img);
9481 		}
9482 
9483 		if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
9484 		    msl_options.swizzle_texture_samples && args.base.is_gather)
9485 		{
9486 			// Add the swizzle constant from the swizzle buffer.
9487 			farg_str += ", " + to_swizzle_expression(img);
9488 			used_swizzle_buffer = true;
9489 		}
9490 
9491 		// Swizzled gather puts the component before the other args, to allow template
9492 		// deduction to work.
9493 		if (args.component && msl_options.swizzle_texture_samples)
9494 		{
9495 			forward = should_forward(args.component);
9496 			farg_str += ", " + to_component_argument(args.component);
9497 		}
9498 	}
9499 
9500 	// Texture coordinates
9501 	forward = forward && should_forward(args.coord);
9502 	auto coord_expr = to_enclosed_expression(args.coord);
9503 	auto &coord_type = expression_type(args.coord);
9504 	bool coord_is_fp = type_is_floating_point(coord_type);
9505 	bool is_cube_fetch = false;
9506 
9507 	string tex_coords = coord_expr;
9508 	uint32_t alt_coord_component = 0;
9509 
9510 	switch (imgtype.image.dim)
9511 	{
9512 
9513 	case Dim1D:
9514 		if (coord_type.vecsize > 1)
9515 			tex_coords = enclose_expression(tex_coords) + ".x";
9516 
9517 		if (args.base.is_fetch)
9518 			tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9519 		else if (sampling_type_needs_f32_conversion(coord_type))
9520 			tex_coords = convert_to_f32(tex_coords, 1);
9521 
9522 		if (msl_options.texture_1D_as_2D)
9523 		{
9524 			if (args.base.is_fetch)
9525 				tex_coords = "uint2(" + tex_coords + ", 0)";
9526 			else
9527 				tex_coords = "float2(" + tex_coords + ", 0.5)";
9528 		}
9529 
9530 		alt_coord_component = 1;
9531 		break;
9532 
9533 	case DimBuffer:
9534 		if (coord_type.vecsize > 1)
9535 			tex_coords = enclose_expression(tex_coords) + ".x";
9536 
9537 		if (msl_options.texture_buffer_native)
9538 		{
9539 			tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9540 		}
9541 		else
9542 		{
9543 			// Metal texel buffer textures are 2D, so convert 1D coord to 2D.
9544 			// Support for Metal 2.1's new texture_buffer type.
9545 			if (args.base.is_fetch)
9546 			{
9547 				if (msl_options.texel_buffer_texture_width > 0)
9548 				{
9549 					tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9550 				}
9551 				else
9552 				{
9553 					tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ", " +
9554 					             to_expression(img) + ")";
9555 				}
9556 			}
9557 		}
9558 
9559 		alt_coord_component = 1;
9560 		break;
9561 
9562 	case DimSubpassData:
9563 		// If we're using Metal's native frame-buffer fetch API for subpass inputs,
9564 		// this path will not be hit.
9565 		tex_coords = "uint2(gl_FragCoord.xy)";
9566 		alt_coord_component = 2;
9567 		break;
9568 
9569 	case Dim2D:
9570 		if (coord_type.vecsize > 2)
9571 			tex_coords = enclose_expression(tex_coords) + ".xy";
9572 
9573 		if (args.base.is_fetch)
9574 			tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9575 		else if (sampling_type_needs_f32_conversion(coord_type))
9576 			tex_coords = convert_to_f32(tex_coords, 2);
9577 
9578 		alt_coord_component = 2;
9579 		break;
9580 
9581 	case Dim3D:
9582 		if (coord_type.vecsize > 3)
9583 			tex_coords = enclose_expression(tex_coords) + ".xyz";
9584 
9585 		if (args.base.is_fetch)
9586 			tex_coords = "uint3(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9587 		else if (sampling_type_needs_f32_conversion(coord_type))
9588 			tex_coords = convert_to_f32(tex_coords, 3);
9589 
9590 		alt_coord_component = 3;
9591 		break;
9592 
9593 	case DimCube:
9594 		if (args.base.is_fetch)
9595 		{
9596 			is_cube_fetch = true;
9597 			tex_coords += ".xy";
9598 			tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9599 		}
9600 		else
9601 		{
9602 			if (coord_type.vecsize > 3)
9603 				tex_coords = enclose_expression(tex_coords) + ".xyz";
9604 		}
9605 
9606 		if (sampling_type_needs_f32_conversion(coord_type))
9607 			tex_coords = convert_to_f32(tex_coords, 3);
9608 
9609 		alt_coord_component = 3;
9610 		break;
9611 
9612 	default:
9613 		break;
9614 	}
9615 
9616 	if (args.base.is_fetch && (args.offset || args.coffset))
9617 	{
9618 		uint32_t offset_expr = args.offset ? args.offset : args.coffset;
9619 		// Fetch offsets must be applied directly to the coordinate.
9620 		forward = forward && should_forward(offset_expr);
9621 		auto &type = expression_type(offset_expr);
9622 		if (imgtype.image.dim == Dim1D && msl_options.texture_1D_as_2D)
9623 		{
9624 			if (type.basetype != SPIRType::UInt)
9625 				tex_coords += join(" + uint2(", bitcast_expression(SPIRType::UInt, offset_expr), ", 0)");
9626 			else
9627 				tex_coords += join(" + uint2(", to_enclosed_expression(offset_expr), ", 0)");
9628 		}
9629 		else
9630 		{
9631 			if (type.basetype != SPIRType::UInt)
9632 				tex_coords += " + " + bitcast_expression(SPIRType::UInt, offset_expr);
9633 			else
9634 				tex_coords += " + " + to_enclosed_expression(offset_expr);
9635 		}
9636 	}
9637 
9638 	// If projection, use alt coord as divisor
9639 	if (args.base.is_proj)
9640 	{
9641 		if (sampling_type_needs_f32_conversion(coord_type))
9642 			tex_coords += " / " + convert_to_f32(to_extract_component_expression(args.coord, alt_coord_component), 1);
9643 		else
9644 			tex_coords += " / " + to_extract_component_expression(args.coord, alt_coord_component);
9645 	}
9646 
9647 	if (!farg_str.empty())
9648 		farg_str += ", ";
9649 
9650 	if (imgtype.image.dim == DimCube && imgtype.image.arrayed && msl_options.emulate_cube_array)
9651 	{
9652 		farg_str += "spvCubemapTo2DArrayFace(" + tex_coords + ").xy";
9653 
9654 		if (is_cube_fetch)
9655 			farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ")";
9656 		else
9657 			farg_str +=
9658 			    ", uint(spvCubemapTo2DArrayFace(" + tex_coords + ").z) + (uint(" +
9659 			    round_fp_tex_coords(to_extract_component_expression(args.coord, alt_coord_component), coord_is_fp) +
9660 			    ") * 6u)";
9661 
9662 		add_spv_func_and_recompile(SPVFuncImplCubemapTo2DArrayFace);
9663 	}
9664 	else
9665 	{
9666 		farg_str += tex_coords;
9667 
9668 		// If fetch from cube, add face explicitly
9669 		if (is_cube_fetch)
9670 		{
9671 			// Special case for cube arrays, face and layer are packed in one dimension.
9672 			if (imgtype.image.arrayed)
9673 				farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ") % 6u";
9674 			else
9675 				farg_str +=
9676 				    ", uint(" + round_fp_tex_coords(to_extract_component_expression(args.coord, 2), coord_is_fp) + ")";
9677 		}
9678 
9679 		// If array, use alt coord
9680 		if (imgtype.image.arrayed)
9681 		{
9682 			// Special case for cube arrays, face and layer are packed in one dimension.
9683 			if (imgtype.image.dim == DimCube && args.base.is_fetch)
9684 			{
9685 				farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ") / 6u";
9686 			}
9687 			else
9688 			{
9689 				farg_str +=
9690 				    ", uint(" +
9691 				    round_fp_tex_coords(to_extract_component_expression(args.coord, alt_coord_component), coord_is_fp) +
9692 				    ")";
9693 				if (imgtype.image.dim == DimSubpassData)
9694 				{
9695 					if (msl_options.multiview)
9696 						farg_str += " + gl_ViewIndex";
9697 					else if (msl_options.arrayed_subpass_input)
9698 						farg_str += " + gl_Layer";
9699 				}
9700 			}
9701 		}
9702 		else if (imgtype.image.dim == DimSubpassData)
9703 		{
9704 			if (msl_options.multiview)
9705 				farg_str += ", gl_ViewIndex";
9706 			else if (msl_options.arrayed_subpass_input)
9707 				farg_str += ", gl_Layer";
9708 		}
9709 	}
9710 
9711 	// Depth compare reference value
9712 	if (args.dref)
9713 	{
9714 		forward = forward && should_forward(args.dref);
9715 		farg_str += ", ";
9716 
9717 		auto &dref_type = expression_type(args.dref);
9718 
9719 		string dref_expr;
9720 		if (args.base.is_proj)
9721 			dref_expr = join(to_enclosed_expression(args.dref), " / ",
9722 			                 to_extract_component_expression(args.coord, alt_coord_component));
9723 		else
9724 			dref_expr = to_expression(args.dref);
9725 
9726 		if (sampling_type_needs_f32_conversion(dref_type))
9727 			dref_expr = convert_to_f32(dref_expr, 1);
9728 
9729 		farg_str += dref_expr;
9730 
9731 		if (msl_options.is_macos() && (grad_x || grad_y))
9732 		{
9733 			// For sample compare, MSL does not support gradient2d for all targets (only iOS apparently according to docs).
9734 			// However, the most common case here is to have a constant gradient of 0, as that is the only way to express
9735 			// LOD == 0 in GLSL with sampler2DArrayShadow (cascaded shadow mapping).
9736 			// We will detect a compile-time constant 0 value for gradient and promote that to level(0) on MSL.
9737 			bool constant_zero_x = !grad_x || expression_is_constant_null(grad_x);
9738 			bool constant_zero_y = !grad_y || expression_is_constant_null(grad_y);
9739 			if (constant_zero_x && constant_zero_y)
9740 			{
9741 				lod = 0;
9742 				grad_x = 0;
9743 				grad_y = 0;
9744 				farg_str += ", level(0)";
9745 			}
9746 			else if (!msl_options.supports_msl_version(2, 3))
9747 			{
9748 				SPIRV_CROSS_THROW("Using non-constant 0.0 gradient() qualifier for sample_compare. This is not "
9749 				                  "supported on macOS prior to MSL 2.3.");
9750 			}
9751 		}
9752 
9753 		if (msl_options.is_macos() && bias)
9754 		{
9755 			// Bias is not supported either on macOS with sample_compare.
9756 			// Verify it is compile-time zero, and drop the argument.
9757 			if (expression_is_constant_null(bias))
9758 			{
9759 				bias = 0;
9760 			}
9761 			else if (!msl_options.supports_msl_version(2, 3))
9762 			{
9763 				SPIRV_CROSS_THROW("Using non-constant 0.0 bias() qualifier for sample_compare. This is not supported "
9764 				                  "on macOS prior to MSL 2.3.");
9765 			}
9766 		}
9767 	}
9768 
9769 	// LOD Options
9770 	// Metal does not support LOD for 1D textures.
9771 	if (bias && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
9772 	{
9773 		forward = forward && should_forward(bias);
9774 		farg_str += ", bias(" + to_expression(bias) + ")";
9775 	}
9776 
9777 	// Metal does not support LOD for 1D textures.
9778 	if (lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
9779 	{
9780 		forward = forward && should_forward(lod);
9781 		if (args.base.is_fetch)
9782 		{
9783 			farg_str += ", " + to_expression(lod);
9784 		}
9785 		else
9786 		{
9787 			farg_str += ", level(" + to_expression(lod) + ")";
9788 		}
9789 	}
9790 	else if (args.base.is_fetch && !lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D) &&
9791 	         imgtype.image.dim != DimBuffer && !imgtype.image.ms && imgtype.image.sampled != 2)
9792 	{
9793 		// Lod argument is optional in OpImageFetch, but we require a LOD value, pick 0 as the default.
9794 		// Check for sampled type as well, because is_fetch is also used for OpImageRead in MSL.
9795 		farg_str += ", 0";
9796 	}
9797 
9798 	// Metal does not support LOD for 1D textures.
9799 	if ((grad_x || grad_y) && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
9800 	{
9801 		forward = forward && should_forward(grad_x);
9802 		forward = forward && should_forward(grad_y);
9803 		string grad_opt;
9804 		switch (imgtype.image.dim)
9805 		{
9806 		case Dim1D:
9807 		case Dim2D:
9808 			grad_opt = "2d";
9809 			break;
9810 		case Dim3D:
9811 			grad_opt = "3d";
9812 			break;
9813 		case DimCube:
9814 			if (imgtype.image.arrayed && msl_options.emulate_cube_array)
9815 				grad_opt = "2d";
9816 			else
9817 				grad_opt = "cube";
9818 			break;
9819 		default:
9820 			grad_opt = "unsupported_gradient_dimension";
9821 			break;
9822 		}
9823 		farg_str += ", gradient" + grad_opt + "(" + to_expression(grad_x) + ", " + to_expression(grad_y) + ")";
9824 	}
9825 
9826 	if (args.min_lod)
9827 	{
9828 		if (!msl_options.supports_msl_version(2, 2))
9829 			SPIRV_CROSS_THROW("min_lod_clamp() is only supported in MSL 2.2+ and up.");
9830 
9831 		forward = forward && should_forward(args.min_lod);
9832 		farg_str += ", min_lod_clamp(" + to_expression(args.min_lod) + ")";
9833 	}
9834 
9835 	// Add offsets
9836 	string offset_expr;
9837 	const SPIRType *offset_type = nullptr;
9838 	if (args.coffset && !args.base.is_fetch)
9839 	{
9840 		forward = forward && should_forward(args.coffset);
9841 		offset_expr = to_expression(args.coffset);
9842 		offset_type = &expression_type(args.coffset);
9843 	}
9844 	else if (args.offset && !args.base.is_fetch)
9845 	{
9846 		forward = forward && should_forward(args.offset);
9847 		offset_expr = to_expression(args.offset);
9848 		offset_type = &expression_type(args.offset);
9849 	}
9850 
9851 	if (!offset_expr.empty())
9852 	{
9853 		switch (imgtype.image.dim)
9854 		{
9855 		case Dim1D:
9856 			if (!msl_options.texture_1D_as_2D)
9857 				break;
9858 			if (offset_type->vecsize > 1)
9859 				offset_expr = enclose_expression(offset_expr) + ".x";
9860 
9861 			farg_str += join(", int2(", offset_expr, ", 0)");
9862 			break;
9863 
9864 		case Dim2D:
9865 			if (offset_type->vecsize > 2)
9866 				offset_expr = enclose_expression(offset_expr) + ".xy";
9867 
9868 			farg_str += ", " + offset_expr;
9869 			break;
9870 
9871 		case Dim3D:
9872 			if (offset_type->vecsize > 3)
9873 				offset_expr = enclose_expression(offset_expr) + ".xyz";
9874 
9875 			farg_str += ", " + offset_expr;
9876 			break;
9877 
9878 		default:
9879 			break;
9880 		}
9881 	}
9882 
9883 	if (args.component)
9884 	{
9885 		// If 2D has gather component, ensure it also has an offset arg
9886 		if (imgtype.image.dim == Dim2D && offset_expr.empty())
9887 			farg_str += ", int2(0)";
9888 
9889 		if (!msl_options.swizzle_texture_samples || is_dynamic_img_sampler)
9890 		{
9891 			forward = forward && should_forward(args.component);
9892 
9893 			uint32_t image_var = 0;
9894 			if (const auto *combined = maybe_get<SPIRCombinedImageSampler>(img))
9895 			{
9896 				if (const auto *img_var = maybe_get_backing_variable(combined->image))
9897 					image_var = img_var->self;
9898 			}
9899 			else if (const auto *var = maybe_get_backing_variable(img))
9900 			{
9901 				image_var = var->self;
9902 			}
9903 
9904 			if (image_var == 0 || !image_is_comparison(expression_type(image_var), image_var))
9905 				farg_str += ", " + to_component_argument(args.component);
9906 		}
9907 	}
9908 
9909 	if (args.sample)
9910 	{
9911 		forward = forward && should_forward(args.sample);
9912 		farg_str += ", ";
9913 		farg_str += to_expression(args.sample);
9914 	}
9915 
9916 	*p_forward = forward;
9917 
9918 	return farg_str;
9919 }
9920 
9921 // If the texture coordinates are floating point, invokes MSL round() function to round them.
round_fp_tex_coords(string tex_coords,bool coord_is_fp)9922 string CompilerMSL::round_fp_tex_coords(string tex_coords, bool coord_is_fp)
9923 {
9924 	return coord_is_fp ? ("round(" + tex_coords + ")") : tex_coords;
9925 }
9926 
9927 // Returns a string to use in an image sampling function argument.
9928 // The ID must be a scalar constant.
to_component_argument(uint32_t id)9929 string CompilerMSL::to_component_argument(uint32_t id)
9930 {
9931 	uint32_t component_index = evaluate_constant_u32(id);
9932 	switch (component_index)
9933 	{
9934 	case 0:
9935 		return "component::x";
9936 	case 1:
9937 		return "component::y";
9938 	case 2:
9939 		return "component::z";
9940 	case 3:
9941 		return "component::w";
9942 
9943 	default:
9944 		SPIRV_CROSS_THROW("The value (" + to_string(component_index) + ") of OpConstant ID " + to_string(id) +
9945 		                  " is not a valid Component index, which must be one of 0, 1, 2, or 3.");
9946 	}
9947 }
9948 
// Establish sampled image as expression object and assign the sampler to it.
// No MSL text is emitted here: result_id is simply registered as a SPIRCombinedImageSampler
// constructed from result_type, image_id and samp_id.
void CompilerMSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
{
	set<SPIRCombinedImageSampler>(result_id, result_type, image_id, samp_id);
}
9954 
// Builds the MSL expression for a texture instruction.
// The core expression comes from CompilerGLSL::to_texture_op(); this override wraps it in one of:
//  - Y'CbCr model-conversion / range-expansion helper calls plus a component swizzle, when the image
//    is backed by a constexpr sampler with Y'CbCr conversion enabled, or
//  - a spvTextureSwizzle() call, when msl_options.swizzle_texture_samples applies to this image.
// Wrapper calls are opened before the inner expression is generated and closed afterwards, so the
// two conditional sections below must stay in agreement.
string CompilerMSL::to_texture_op(const Instruction &i, bool sparse, bool *forward,
                                  SmallVector<uint32_t> &inherited_expressions)
{
	auto *ops = stream(i);
	uint32_t result_type_id = ops[0];
	uint32_t img = ops[2];
	auto &result_type = get<SPIRType>(result_type_id);
	auto op = static_cast<Op>(i.op);
	bool is_gather = (op == OpImageGather || op == OpImageDrefGather);

	// Bypass pointers because we need the real image struct
	auto &type = expression_type(img);
	auto &imgtype = get<SPIRType>(type.self);

	// Look up any constexpr sampler and dynamic image-sampler decoration attached to the backing variable.
	const MSLConstexprSampler *constexpr_sampler = nullptr;
	bool is_dynamic_img_sampler = false;
	if (auto *var = maybe_get_backing_variable(img))
	{
		constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
		is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
	}

	// Phase 1: open any wrapper calls. Only the opening "name(" text is appended here.
	string expr;
	if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
	{
		// If this needs sampler Y'CbCr conversion, we need to do some additional
		// processing.
		switch (constexpr_sampler->ycbcr_model)
		{
		case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
		case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
			// Default
			break;
		case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
			add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT709);
			expr += "spvConvertYCbCrBT709(";
			break;
		case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
			add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT601);
			expr += "spvConvertYCbCrBT601(";
			break;
		case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
			add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT2020);
			expr += "spvConvertYCbCrBT2020(";
			break;
		default:
			SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
		}

		// Every model except RGB identity also gets a range-expansion call nested inside the conversion call.
		if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
		{
			switch (constexpr_sampler->ycbcr_range)
			{
			case MSL_SAMPLER_YCBCR_RANGE_ITU_FULL:
				add_spv_func_and_recompile(SPVFuncImplExpandITUFullRange);
				expr += "spvExpandITUFullRange(";
				break;
			case MSL_SAMPLER_YCBCR_RANGE_ITU_NARROW:
				add_spv_func_and_recompile(SPVFuncImplExpandITUNarrowRange);
				expr += "spvExpandITUNarrowRange(";
				break;
			default:
				SPIRV_CROSS_THROW("Invalid Y'CbCr range.");
			}
		}
	}
	else if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype) &&
	         !is_dynamic_img_sampler)
	{
		// Opened here; closed in phase 3 once the swizzle constant has been appended.
		add_spv_func_and_recompile(SPVFuncImplTextureSwizzle);
		expr += "spvTextureSwizzle(";
	}

	// Phase 2: let the base class generate the actual sampling expression.
	string inner_expr = CompilerGLSL::to_texture_op(i, sparse, forward, inherited_expressions);

	// Phase 3: append the inner expression (with any swizzle) and close the wrappers opened in phase 1.
	if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
	{
		if (!constexpr_sampler->swizzle_is_identity())
		{
			static const char swizzle_names[] = "rgba";
			if (!constexpr_sampler->swizzle_has_one_or_zero())
			{
				// If we can, do it inline.
				// The result is the inner expression followed by a four-letter member swizzle.
				expr += inner_expr + ".";
				for (uint32_t c = 0; c < 4; c++)
				{
					switch (constexpr_sampler->swizzle[c])
					{
					case MSL_COMPONENT_SWIZZLE_IDENTITY:
						expr += swizzle_names[c];
						break;
					case MSL_COMPONENT_SWIZZLE_R:
					case MSL_COMPONENT_SWIZZLE_G:
					case MSL_COMPONENT_SWIZZLE_B:
					case MSL_COMPONENT_SWIZZLE_A:
						// Index into "rgba" by the enumerator's distance from the R enumerator.
						expr += swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
						break;
					default:
						SPIRV_CROSS_THROW("Invalid component swizzle.");
					}
				}
			}
			else
			{
				// Otherwise, we need to emit a temporary and swizzle that.
				// A fresh ID is allocated for the temporary, which takes over the inherited dependencies.
				uint32_t temp_id = ir.increase_bound_by(1);
				emit_op(result_type_id, temp_id, inner_expr, false);
				for (auto &inherit : inherited_expressions)
					inherit_expression_dependencies(temp_id, inherit);
				inherited_expressions.clear();
				inherited_expressions.push_back(temp_id);

				// Implicit-LOD sampling ops additionally register the temporary as control dependent.
				switch (op)
				{
				case OpImageSampleDrefImplicitLod:
				case OpImageSampleImplicitLod:
				case OpImageSampleProjImplicitLod:
				case OpImageSampleProjDrefImplicitLod:
					register_control_dependent_expression(temp_id);
					break;

				default:
					break;
				}
				// Build a constructor call of the result type with one argument per component.
				expr += type_to_glsl(result_type) + "(";
				for (uint32_t c = 0; c < 4; c++)
				{
					switch (constexpr_sampler->swizzle[c])
					{
					case MSL_COMPONENT_SWIZZLE_IDENTITY:
						expr += to_expression(temp_id) + "." + swizzle_names[c];
						break;
					case MSL_COMPONENT_SWIZZLE_ZERO:
						expr += "0";
						break;
					case MSL_COMPONENT_SWIZZLE_ONE:
						expr += "1";
						break;
					case MSL_COMPONENT_SWIZZLE_R:
					case MSL_COMPONENT_SWIZZLE_G:
					case MSL_COMPONENT_SWIZZLE_B:
					case MSL_COMPONENT_SWIZZLE_A:
						expr += to_expression(temp_id) + "." +
						        swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
						break;
					default:
						SPIRV_CROSS_THROW("Invalid component swizzle.");
					}
					if (c < 3)
						expr += ", ";
				}
				expr += ")";
			}
		}
		else
			expr += inner_expr;
		if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
		{
			// Closes the range-expansion call, passing the sampler's bpc value as its last argument.
			expr += join(", ", constexpr_sampler->bpc, ")");
			// Closes the model-conversion call, which was only opened for the non-identity models.
			if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY)
				expr += ")";
		}
	}
	else
	{
		expr += inner_expr;
		if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype) &&
		    !is_dynamic_img_sampler)
		{
			// Add the swizzle constant from the swizzle buffer.
			expr += ", " + to_swizzle_expression(img) + ")";
			used_swizzle_buffer = true;
		}
	}

	return expr;
}
10132 
create_swizzle(MSLComponentSwizzle swizzle)10133 static string create_swizzle(MSLComponentSwizzle swizzle)
10134 {
10135 	switch (swizzle)
10136 	{
10137 	case MSL_COMPONENT_SWIZZLE_IDENTITY:
10138 		return "spvSwizzle::none";
10139 	case MSL_COMPONENT_SWIZZLE_ZERO:
10140 		return "spvSwizzle::zero";
10141 	case MSL_COMPONENT_SWIZZLE_ONE:
10142 		return "spvSwizzle::one";
10143 	case MSL_COMPONENT_SWIZZLE_R:
10144 		return "spvSwizzle::red";
10145 	case MSL_COMPONENT_SWIZZLE_G:
10146 		return "spvSwizzle::green";
10147 	case MSL_COMPONENT_SWIZZLE_B:
10148 		return "spvSwizzle::blue";
10149 	case MSL_COMPONENT_SWIZZLE_A:
10150 		return "spvSwizzle::alpha";
10151 	default:
10152 		SPIRV_CROSS_THROW("Invalid component swizzle.");
10153 	}
10154 }
10155 
// Returns a string representation of the ID, usable as a function arg.
// Manufacture automatic sampler arg for SampledImage texture.
// Besides the base argument text, the returned string may carry extra comma-separated
// arguments appended below: additional plane textures, a sampler, Y'CbCr sampler info,
// a swizzle constant, a buffer size and an "_atomic" companion.
// `arg` is the callee parameter being bound and `id` is the value passed for it.
string CompilerMSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
{
	string arg_str;

	auto &type = expression_type(id);
	bool is_dynamic_img_sampler = has_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
	// If the argument *itself* is a "dynamic" combined-image sampler, then we can just pass that around.
	bool arg_is_dynamic_img_sampler = has_extended_decoration(id, SPIRVCrossDecorationDynamicImageSampler);
	// Otherwise open a spvDynamicImageSampler<T>( constructor call; it is closed near the end of this function.
	if (is_dynamic_img_sampler && !arg_is_dynamic_img_sampler)
		arg_str = join("spvDynamicImageSampler<", type_to_glsl(get<SPIRType>(type.image.type)), ">(");

	auto *c = maybe_get<SPIRConstant>(id);
	if (msl_options.force_native_arrays && c && !get<SPIRType>(c->constant_type).array.empty())
	{
		// If we are passing a constant array directly to a function for some reason,
		// the callee will expect an argument in thread const address space
		// (since we can only bind to arrays with references in MSL).
		// To resolve this, we must emit a copy in this address space.
		// This kind of code gen should be rare enough that performance is not a real concern.
		// Inline the SPIR-V to avoid this kind of suboptimal codegen.
		//
		// We risk calling this inside a continue block (invalid code),
		// so just create a thread local copy in the current function.
		arg_str = join("_", id, "_array_copy");
		auto &constants = current_function->constant_arrays_needed_on_stack;
		auto itr = find(begin(constants), end(constants), ID(id));
		// First time this constant is seen in the current function: record it and force another compile pass.
		if (itr == end(constants))
		{
			force_recompile();
			constants.push_back(id);
		}
	}
	else
		arg_str += CompilerGLSL::to_func_call_arg(arg, id);

	// Need to check the base variable in case we need to apply a qualified alias.
	// var_id stays 0 when id is not a variable.
	uint32_t var_id = 0;
	auto *var = maybe_get<SPIRVariable>(id);
	if (var)
		var_id = var->basevariable;

	if (!arg_is_dynamic_img_sampler)
	{
		auto *constexpr_sampler = find_constexpr_sampler(var_id ? var_id : id);
		if (type.basetype == SPIRType::SampledImage)
		{
			// Manufacture automatic plane args for multiplanar texture
			uint32_t planes = 1;
			if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
			{
				planes = constexpr_sampler->planes;
				// If this parameter isn't aliasing a global, then we need to use
				// the special "dynamic image-sampler" class to pass it--and we need
				// to use it for *every* non-alias parameter, in case a combined
				// image-sampler with a Y'CbCr conversion is passed. Hopefully, this
				// pathological case is so rare that it should never be hit in practice.
				if (!arg.alias_global_variable)
					add_spv_func_and_recompile(SPVFuncImplDynamicImageSampler);
			}
			// Plane 0 is the base argument itself; planes 1..N-1 are the base text plus plane_name_suffix and the index.
			for (uint32_t i = 1; i < planes; i++)
				arg_str += join(", ", CompilerGLSL::to_func_call_arg(arg, id), plane_name_suffix, i);
			// Manufacture automatic sampler arg if the arg is a SampledImage texture.
			if (type.image.dim != DimBuffer)
				arg_str += ", " + to_sampler_expression(var_id ? var_id : id);

			// Add sampler Y'CbCr conversion info if we have it
			if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
			{
				// Collects the spvYCbCrSampler constructor arguments; several values add no argument.
				SmallVector<string> samp_args;

				switch (constexpr_sampler->resolution)
				{
				case MSL_FORMAT_RESOLUTION_444:
					// Default
					break;
				case MSL_FORMAT_RESOLUTION_422:
					samp_args.push_back("spvFormatResolution::_422");
					break;
				case MSL_FORMAT_RESOLUTION_420:
					samp_args.push_back("spvFormatResolution::_420");
					break;
				default:
					SPIRV_CROSS_THROW("Invalid format resolution.");
				}

				if (constexpr_sampler->chroma_filter != MSL_SAMPLER_FILTER_NEAREST)
					samp_args.push_back("spvChromaFilter::linear");

				if (constexpr_sampler->x_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
					samp_args.push_back("spvXChromaLocation::midpoint");
				if (constexpr_sampler->y_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
					samp_args.push_back("spvYChromaLocation::midpoint");
				switch (constexpr_sampler->ycbcr_model)
				{
				case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
					// Default
					break;
				case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
					samp_args.push_back("spvYCbCrModelConversion::ycbcr_identity");
					break;
				case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
					samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_709");
					break;
				case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
					samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_601");
					break;
				case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
					samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_2020");
					break;
				default:
					SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
				}
				if (constexpr_sampler->ycbcr_range != MSL_SAMPLER_YCBCR_RANGE_ITU_FULL)
					samp_args.push_back("spvYCbCrRange::itu_narrow");
				// The component bit count is always passed.
				samp_args.push_back(join("spvComponentBits(", constexpr_sampler->bpc, ")"));
				arg_str += join(", spvYCbCrSampler(", merge(samp_args), ")");
			}
		}

		// Swizzle argument: either the constexpr sampler's four components packed one per byte
		// (component 0 in the low byte), or the swizzle expression for the image.
		if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
			arg_str += join(", (uint(", create_swizzle(constexpr_sampler->swizzle[3]), ") << 24) | (uint(",
			                create_swizzle(constexpr_sampler->swizzle[2]), ") << 16) | (uint(",
			                create_swizzle(constexpr_sampler->swizzle[1]), ") << 8) | uint(",
			                create_swizzle(constexpr_sampler->swizzle[0]), ")");
		else if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
			arg_str += ", " + to_swizzle_expression(var_id ? var_id : id);

		// NOTE(review): the lookup uses var_id directly, so the `id` fallback in the call below is
		// only reachable if 0 is itself present in buffers_requiring_array_length - confirm intent.
		if (buffers_requiring_array_length.count(var_id))
			arg_str += ", " + to_buffer_size_expression(var_id ? var_id : id);

		// Closes the spvDynamicImageSampler<T>( call opened at the top of the function.
		if (is_dynamic_img_sampler)
			arg_str += ")";
	}

	// Emulate texture2D atomic operations
	// Images tracked in atomic_image_vars get an extra argument: the variable expression plus "_atomic".
	auto *backing_var = maybe_get_backing_variable(var_id);
	if (backing_var && atomic_image_vars.count(backing_var->self))
	{
		arg_str += ", " + to_expression(var_id) + "_atomic";
	}

	return arg_str;
}
10301 
10302 // If the ID represents a sampled image that has been assigned a sampler already,
10303 // generate an expression for the sampler, otherwise generate a fake sampler name
10304 // by appending a suffix to the expression constructed from the ID.
to_sampler_expression(uint32_t id)10305 string CompilerMSL::to_sampler_expression(uint32_t id)
10306 {
10307 	auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
10308 	auto expr = to_expression(combined ? combined->image : VariableID(id));
10309 	auto index = expr.find_first_of('[');
10310 
10311 	uint32_t samp_id = 0;
10312 	if (combined)
10313 		samp_id = combined->sampler;
10314 
10315 	if (index == string::npos)
10316 		return samp_id ? to_expression(samp_id) : expr + sampler_name_suffix;
10317 	else
10318 	{
10319 		auto image_expr = expr.substr(0, index);
10320 		auto array_expr = expr.substr(index);
10321 		return samp_id ? to_expression(samp_id) : (image_expr + sampler_name_suffix + array_expr);
10322 	}
10323 }
10324 
to_swizzle_expression(uint32_t id)10325 string CompilerMSL::to_swizzle_expression(uint32_t id)
10326 {
10327 	auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
10328 
10329 	auto expr = to_expression(combined ? combined->image : VariableID(id));
10330 	auto index = expr.find_first_of('[');
10331 
10332 	// If an image is part of an argument buffer translate this to a legal identifier.
10333 	string::size_type period = 0;
10334 	while ((period = expr.find_first_of('.', period)) != string::npos && period < index)
10335 		expr[period] = '_';
10336 
10337 	if (index == string::npos)
10338 		return expr + swizzle_name_suffix;
10339 	else
10340 	{
10341 		auto image_expr = expr.substr(0, index);
10342 		auto array_expr = expr.substr(index);
10343 		return image_expr + swizzle_name_suffix + array_expr;
10344 	}
10345 }
10346 
to_buffer_size_expression(uint32_t id)10347 string CompilerMSL::to_buffer_size_expression(uint32_t id)
10348 {
10349 	auto expr = to_expression(id);
10350 	auto index = expr.find_first_of('[');
10351 
10352 	// This is quite crude, but we need to translate the reference name (*spvDescriptorSetN.name) to
10353 	// the pointer expression spvDescriptorSetN.name to make a reasonable expression here.
10354 	// This only happens if we have argument buffers and we are using OpArrayLength on a lone SSBO in that set.
10355 	if (expr.size() >= 3 && expr[0] == '(' && expr[1] == '*')
10356 		expr = address_of_expression(expr);
10357 
10358 	// If a buffer is part of an argument buffer translate this to a legal identifier.
10359 	for (auto &c : expr)
10360 		if (c == '.')
10361 			c = '_';
10362 
10363 	if (index == string::npos)
10364 		return expr + buffer_size_name_suffix;
10365 	else
10366 	{
10367 		auto buffer_expr = expr.substr(0, index);
10368 		auto array_expr = expr.substr(index);
10369 		return buffer_expr + buffer_size_name_suffix + array_expr;
10370 	}
10371 }
10372 
10373 // Checks whether the type is a Block all of whose members have DecorationPatch.
is_patch_block(const SPIRType & type)10374 bool CompilerMSL::is_patch_block(const SPIRType &type)
10375 {
10376 	if (!has_decoration(type.self, DecorationBlock))
10377 		return false;
10378 
10379 	for (uint32_t i = 0; i < type.member_types.size(); i++)
10380 	{
10381 		if (!has_member_decoration(type.self, i, DecorationPatch))
10382 			return false;
10383 	}
10384 
10385 	return true;
10386 }
10387 
10388 // Checks whether the ID is a row_major matrix that requires conversion before use
is_non_native_row_major_matrix(uint32_t id)10389 bool CompilerMSL::is_non_native_row_major_matrix(uint32_t id)
10390 {
10391 	auto *e = maybe_get<SPIRExpression>(id);
10392 	if (e)
10393 		return e->need_transpose;
10394 	else
10395 		return has_decoration(id, DecorationRowMajor);
10396 }
10397 
// Checks whether the member is a row_major matrix that requires conversion before use
// The decision rests solely on whether member `index` of `type` carries DecorationRowMajor.
bool CompilerMSL::member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index)
{
	return has_member_decoration(type.self, index, DecorationRowMajor);
}
10403 
convert_row_major_matrix(string exp_str,const SPIRType & exp_type,uint32_t physical_type_id,bool is_packed)10404 string CompilerMSL::convert_row_major_matrix(string exp_str, const SPIRType &exp_type, uint32_t physical_type_id,
10405                                              bool is_packed)
10406 {
10407 	if (!is_matrix(exp_type))
10408 	{
10409 		return CompilerGLSL::convert_row_major_matrix(move(exp_str), exp_type, physical_type_id, is_packed);
10410 	}
10411 	else
10412 	{
10413 		strip_enclosed_expression(exp_str);
10414 		if (physical_type_id != 0 || is_packed)
10415 			exp_str = unpack_expression_type(exp_str, exp_type, physical_type_id, is_packed, true);
10416 		return join("transpose(", exp_str, ")");
10417 	}
10418 }
10419 
10420 // Called automatically at the end of the entry point function
emit_fixup()10421 void CompilerMSL::emit_fixup()
10422 {
10423 	if (is_vertex_like_shader() && stage_out_var_id && !qual_pos_var_name.empty() && !capture_output_to_buffer)
10424 	{
10425 		if (options.vertex.fixup_clipspace)
10426 			statement(qual_pos_var_name, ".z = (", qual_pos_var_name, ".z + ", qual_pos_var_name,
10427 			          ".w) * 0.5;       // Adjust clip-space for Metal");
10428 
10429 		if (options.vertex.flip_vert_y)
10430 			statement(qual_pos_var_name, ".y = -(", qual_pos_var_name, ".y);", "    // Invert Y-axis for Metal");
10431 	}
10432 }
10433 
10434 // Return a string defining a structure member, with padding and packing.
to_struct_member(const SPIRType & type,uint32_t member_type_id,uint32_t index,const string & qualifier)10435 string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
10436                                      const string &qualifier)
10437 {
10438 	if (member_is_remapped_physical_type(type, index))
10439 		member_type_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID);
10440 	auto &physical_type = get<SPIRType>(member_type_id);
10441 
10442 	// If this member is packed, mark it as so.
10443 	string pack_pfx;
10444 
10445 	// Allow Metal to use the array<T> template to make arrays a value type
10446 	uint32_t orig_id = 0;
10447 	if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID))
10448 		orig_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID);
10449 
10450 	bool row_major = false;
10451 	if (is_matrix(physical_type))
10452 		row_major = has_member_decoration(type.self, index, DecorationRowMajor);
10453 
10454 	SPIRType row_major_physical_type;
10455 	const SPIRType *declared_type = &physical_type;
10456 
10457 	// If a struct is being declared with physical layout,
10458 	// do not use array<T> wrappers.
10459 	// This avoids a lot of complicated cases with packed vectors and matrices,
10460 	// and generally we cannot copy full arrays in and out of buffers into Function
10461 	// address space.
10462 	// Array of resources should also be declared as builtin arrays.
10463 	if (has_member_decoration(type.self, index, DecorationOffset))
10464 		is_using_builtin_array = true;
10465 	else if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary))
10466 		is_using_builtin_array = true;
10467 
10468 	if (member_is_packed_physical_type(type, index))
10469 	{
10470 		// If we're packing a matrix, output an appropriate typedef
10471 		if (physical_type.basetype == SPIRType::Struct)
10472 		{
10473 			SPIRV_CROSS_THROW("Cannot emit a packed struct currently.");
10474 		}
10475 		else if (is_matrix(physical_type))
10476 		{
10477 			uint32_t rows = physical_type.vecsize;
10478 			uint32_t cols = physical_type.columns;
10479 			pack_pfx = "packed_";
10480 			if (row_major)
10481 			{
10482 				// These are stored transposed.
10483 				rows = physical_type.columns;
10484 				cols = physical_type.vecsize;
10485 				pack_pfx = "packed_rm_";
10486 			}
10487 			string base_type = physical_type.width == 16 ? "half" : "float";
10488 			string td_line = "typedef ";
10489 			td_line += "packed_" + base_type + to_string(rows);
10490 			td_line += " " + pack_pfx;
10491 			// Use the actual matrix size here.
10492 			td_line += base_type + to_string(physical_type.columns) + "x" + to_string(physical_type.vecsize);
10493 			td_line += "[" + to_string(cols) + "]";
10494 			td_line += ";";
10495 			add_typedef_line(td_line);
10496 		}
10497 		else if (!is_scalar(physical_type)) // scalar type is already packed.
10498 			pack_pfx = "packed_";
10499 	}
10500 	else if (row_major)
10501 	{
10502 		// Need to declare type with flipped vecsize/columns.
10503 		row_major_physical_type = physical_type;
10504 		swap(row_major_physical_type.vecsize, row_major_physical_type.columns);
10505 		declared_type = &row_major_physical_type;
10506 	}
10507 
10508 	// Very specifically, image load-store in argument buffers are disallowed on MSL on iOS.
10509 	if (msl_options.is_ios() && physical_type.basetype == SPIRType::Image && physical_type.image.sampled == 2)
10510 	{
10511 		if (!has_decoration(orig_id, DecorationNonWritable))
10512 			SPIRV_CROSS_THROW("Writable images are not allowed in argument buffers on iOS.");
10513 	}
10514 
10515 	// Array information is baked into these types.
10516 	string array_type;
10517 	if (physical_type.basetype != SPIRType::Image && physical_type.basetype != SPIRType::Sampler &&
10518 	    physical_type.basetype != SPIRType::SampledImage)
10519 	{
10520 		BuiltIn builtin = BuiltInMax;
10521 
10522 		// Special handling. In [[stage_out]] or [[stage_in]] blocks,
10523 		// we need flat arrays, but if we're somehow declaring gl_PerVertex for constant array reasons, we want
10524 		// template array types to be declared.
10525 		bool is_ib_in_out =
10526 				((stage_out_var_id && get_stage_out_struct_type().self == type.self &&
10527 				  variable_storage_requires_stage_io(StorageClassOutput)) ||
10528 				 (stage_in_var_id && get_stage_in_struct_type().self == type.self &&
10529 				  variable_storage_requires_stage_io(StorageClassInput)));
10530 		if (is_ib_in_out && is_member_builtin(type, index, &builtin))
10531 			is_using_builtin_array = true;
10532 		array_type = type_to_array_glsl(physical_type);
10533 	}
10534 
10535 	auto result = join(pack_pfx, type_to_glsl(*declared_type, orig_id), " ", qualifier, to_member_name(type, index),
10536 	                   member_attribute_qualifier(type, index), array_type, ";");
10537 
10538 	is_using_builtin_array = false;
10539 	return result;
10540 }
10541 
10542 // Emit a structure member, padding and packing to maintain the correct memeber alignments.
emit_struct_member(const SPIRType & type,uint32_t member_type_id,uint32_t index,const string & qualifier,uint32_t)10543 void CompilerMSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
10544                                      const string &qualifier, uint32_t)
10545 {
10546 	// If this member requires padding to maintain its declared offset, emit a dummy padding member before it.
10547 	if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationPaddingTarget))
10548 	{
10549 		uint32_t pad_len = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPaddingTarget);
10550 		statement("char _m", index, "_pad", "[", pad_len, "];");
10551 	}
10552 
10553 	// Handle HLSL-style 0-based vertex/instance index.
10554 	builtin_declaration = true;
10555 	statement(to_struct_member(type, member_type_id, index, qualifier));
10556 	builtin_declaration = false;
10557 }
10558 
emit_struct_padding_target(const SPIRType & type)10559 void CompilerMSL::emit_struct_padding_target(const SPIRType &type)
10560 {
10561 	uint32_t struct_size = get_declared_struct_size_msl(type, true, true);
10562 	uint32_t target_size = get_extended_decoration(type.self, SPIRVCrossDecorationPaddingTarget);
10563 	if (target_size < struct_size)
10564 		SPIRV_CROSS_THROW("Cannot pad with negative bytes.");
10565 	else if (target_size > struct_size)
10566 		statement("char _m0_final_padding[", target_size - struct_size, "];");
10567 }
10568 
10569 // Return a MSL qualifier for the specified function attribute member
member_attribute_qualifier(const SPIRType & type,uint32_t index)10570 string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t index)
10571 {
10572 	auto &execution = get_entry_point();
10573 
10574 	uint32_t mbr_type_id = type.member_types[index];
10575 	auto &mbr_type = get<SPIRType>(mbr_type_id);
10576 
10577 	BuiltIn builtin = BuiltInMax;
10578 	bool is_builtin = is_member_builtin(type, index, &builtin);
10579 
10580 	if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary))
10581 	{
10582 		string quals = join(
10583 		    " [[id(", get_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary), ")");
10584 		if (interlocked_resources.count(
10585 		        get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID)))
10586 			quals += ", raster_order_group(0)";
10587 		quals += "]]";
10588 		return quals;
10589 	}
10590 
10591 	// Vertex function inputs
10592 	if (execution.model == ExecutionModelVertex && type.storage == StorageClassInput)
10593 	{
10594 		if (is_builtin)
10595 		{
10596 			switch (builtin)
10597 			{
10598 			case BuiltInVertexId:
10599 			case BuiltInVertexIndex:
10600 			case BuiltInBaseVertex:
10601 			case BuiltInInstanceId:
10602 			case BuiltInInstanceIndex:
10603 			case BuiltInBaseInstance:
10604 				if (msl_options.vertex_for_tessellation)
10605 					return "";
10606 				return string(" [[") + builtin_qualifier(builtin) + "]]";
10607 
10608 			case BuiltInDrawIndex:
10609 				SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
10610 
10611 			default:
10612 				return "";
10613 			}
10614 		}
10615 
10616 		uint32_t locn;
10617 		if (is_builtin)
10618 			locn = get_or_allocate_builtin_input_member_location(builtin, type.self, index);
10619 		else
10620 			locn = get_member_location(type.self, index);
10621 
10622 		if (locn != k_unknown_location)
10623 			return string(" [[attribute(") + convert_to_string(locn) + ")]]";
10624 	}
10625 
10626 	// Vertex and tessellation evaluation function outputs
10627 	if (((execution.model == ExecutionModelVertex && !msl_options.vertex_for_tessellation) ||
10628 	     execution.model == ExecutionModelTessellationEvaluation) &&
10629 	    type.storage == StorageClassOutput)
10630 	{
10631 		if (is_builtin)
10632 		{
10633 			switch (builtin)
10634 			{
10635 			case BuiltInPointSize:
10636 				// Only mark the PointSize builtin if really rendering points.
10637 				// Some shaders may include a PointSize builtin even when used to render
10638 				// non-point topologies, and Metal will reject this builtin when compiling
10639 				// the shader into a render pipeline that uses a non-point topology.
10640 				return msl_options.enable_point_size_builtin ? (string(" [[") + builtin_qualifier(builtin) + "]]") : "";
10641 
10642 			case BuiltInViewportIndex:
10643 				if (!msl_options.supports_msl_version(2, 0))
10644 					SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
10645 				/* fallthrough */
10646 			case BuiltInPosition:
10647 			case BuiltInLayer:
10648 				return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10649 
10650 			case BuiltInClipDistance:
10651 				if (has_member_decoration(type.self, index, DecorationIndex))
10652 					return join(" [[user(clip", get_member_decoration(type.self, index, DecorationIndex), ")]]");
10653 				else
10654 					return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10655 
10656 			case BuiltInCullDistance:
10657 				if (has_member_decoration(type.self, index, DecorationIndex))
10658 					return join(" [[user(cull", get_member_decoration(type.self, index, DecorationIndex), ")]]");
10659 				else
10660 					return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10661 
10662 			default:
10663 				return "";
10664 			}
10665 		}
10666 		uint32_t comp;
10667 		uint32_t locn = get_member_location(type.self, index, &comp);
10668 		if (locn != k_unknown_location)
10669 		{
10670 			if (comp != k_unknown_component)
10671 				return string(" [[user(locn") + convert_to_string(locn) + "_" + convert_to_string(comp) + ")]]";
10672 			else
10673 				return string(" [[user(locn") + convert_to_string(locn) + ")]]";
10674 		}
10675 	}
10676 
10677 	// Tessellation control function inputs
10678 	if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassInput)
10679 	{
10680 		if (is_builtin)
10681 		{
10682 			switch (builtin)
10683 			{
10684 			case BuiltInInvocationId:
10685 			case BuiltInPrimitiveId:
10686 				if (msl_options.multi_patch_workgroup)
10687 					return "";
10688 				return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10689 			case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
10690 			case BuiltInSubgroupSize: // FIXME: Should work in any stage
10691 				if (msl_options.emulate_subgroups)
10692 					return "";
10693 				return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10694 			case BuiltInPatchVertices:
10695 				return "";
10696 			// Others come from stage input.
10697 			default:
10698 				break;
10699 			}
10700 		}
10701 		if (msl_options.multi_patch_workgroup)
10702 			return "";
10703 
10704 		uint32_t locn;
10705 		if (is_builtin)
10706 			locn = get_or_allocate_builtin_input_member_location(builtin, type.self, index);
10707 		else
10708 			locn = get_member_location(type.self, index);
10709 
10710 		if (locn != k_unknown_location)
10711 			return string(" [[attribute(") + convert_to_string(locn) + ")]]";
10712 	}
10713 
10714 	// Tessellation control function outputs
10715 	if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassOutput)
10716 	{
10717 		// For this type of shader, we always arrange for it to capture its
10718 		// output to a buffer. For this reason, qualifiers are irrelevant here.
10719 		return "";
10720 	}
10721 
10722 	// Tessellation evaluation function inputs
10723 	if (execution.model == ExecutionModelTessellationEvaluation && type.storage == StorageClassInput)
10724 	{
10725 		if (is_builtin)
10726 		{
10727 			switch (builtin)
10728 			{
10729 			case BuiltInPrimitiveId:
10730 			case BuiltInTessCoord:
10731 				return string(" [[") + builtin_qualifier(builtin) + "]]";
10732 			case BuiltInPatchVertices:
10733 				return "";
10734 			// Others come from stage input.
10735 			default:
10736 				break;
10737 			}
10738 		}
10739 		// The special control point array must not be marked with an attribute.
10740 		if (get_type(type.member_types[index]).basetype == SPIRType::ControlPointArray)
10741 			return "";
10742 
10743 		uint32_t locn;
10744 		if (is_builtin)
10745 			locn = get_or_allocate_builtin_input_member_location(builtin, type.self, index);
10746 		else
10747 			locn = get_member_location(type.self, index);
10748 
10749 		if (locn != k_unknown_location)
10750 			return string(" [[attribute(") + convert_to_string(locn) + ")]]";
10751 	}
10752 
10753 	// Tessellation evaluation function outputs were handled above.
10754 
10755 	// Fragment function inputs
10756 	if (execution.model == ExecutionModelFragment && type.storage == StorageClassInput)
10757 	{
10758 		string quals;
10759 		if (is_builtin)
10760 		{
10761 			switch (builtin)
10762 			{
10763 			case BuiltInViewIndex:
10764 				if (!msl_options.multiview || !msl_options.multiview_layered_rendering)
10765 					break;
10766 				/* fallthrough */
10767 			case BuiltInFrontFacing:
10768 			case BuiltInPointCoord:
10769 			case BuiltInFragCoord:
10770 			case BuiltInSampleId:
10771 			case BuiltInSampleMask:
10772 			case BuiltInLayer:
10773 			case BuiltInBaryCoordNV:
10774 			case BuiltInBaryCoordNoPerspNV:
10775 				quals = builtin_qualifier(builtin);
10776 				break;
10777 
10778 			case BuiltInClipDistance:
10779 				return join(" [[user(clip", get_member_decoration(type.self, index, DecorationIndex), ")]]");
10780 			case BuiltInCullDistance:
10781 				return join(" [[user(cull", get_member_decoration(type.self, index, DecorationIndex), ")]]");
10782 
10783 			default:
10784 				break;
10785 			}
10786 		}
10787 		else
10788 		{
10789 			uint32_t comp;
10790 			uint32_t locn = get_member_location(type.self, index, &comp);
10791 			if (locn != k_unknown_location)
10792 			{
10793 				// For user-defined attributes, this is fine. From Vulkan spec:
10794 				// A user-defined output variable is considered to match an input variable in the subsequent stage if
10795 				// the two variables are declared with the same Location and Component decoration and match in type
10796 				// and decoration, except that interpolation decorations are not required to match. For the purposes
10797 				// of interface matching, variables declared without a Component decoration are considered to have a
10798 				// Component decoration of zero.
10799 
10800 				if (comp != k_unknown_component && comp != 0)
10801 					quals = string("user(locn") + convert_to_string(locn) + "_" + convert_to_string(comp) + ")";
10802 				else
10803 					quals = string("user(locn") + convert_to_string(locn) + ")";
10804 			}
10805 		}
10806 
10807 		if (builtin == BuiltInBaryCoordNV || builtin == BuiltInBaryCoordNoPerspNV)
10808 		{
10809 			if (has_member_decoration(type.self, index, DecorationFlat) ||
10810 			    has_member_decoration(type.self, index, DecorationCentroid) ||
10811 			    has_member_decoration(type.self, index, DecorationSample) ||
10812 			    has_member_decoration(type.self, index, DecorationNoPerspective))
10813 			{
10814 				// NoPerspective is baked into the builtin type.
10815 				SPIRV_CROSS_THROW(
10816 				    "Flat, Centroid, Sample, NoPerspective decorations are not supported for BaryCoord inputs.");
10817 			}
10818 		}
10819 
10820 		// Don't bother decorating integers with the 'flat' attribute; it's
10821 		// the default (in fact, the only option). Also don't bother with the
10822 		// FragCoord builtin; it's always noperspective on Metal.
10823 		if (!type_is_integral(mbr_type) && (!is_builtin || builtin != BuiltInFragCoord))
10824 		{
10825 			if (has_member_decoration(type.self, index, DecorationFlat))
10826 			{
10827 				if (!quals.empty())
10828 					quals += ", ";
10829 				quals += "flat";
10830 			}
10831 			else if (has_member_decoration(type.self, index, DecorationCentroid))
10832 			{
10833 				if (!quals.empty())
10834 					quals += ", ";
10835 				if (has_member_decoration(type.self, index, DecorationNoPerspective))
10836 					quals += "centroid_no_perspective";
10837 				else
10838 					quals += "centroid_perspective";
10839 			}
10840 			else if (has_member_decoration(type.self, index, DecorationSample))
10841 			{
10842 				if (!quals.empty())
10843 					quals += ", ";
10844 				if (has_member_decoration(type.self, index, DecorationNoPerspective))
10845 					quals += "sample_no_perspective";
10846 				else
10847 					quals += "sample_perspective";
10848 			}
10849 			else if (has_member_decoration(type.self, index, DecorationNoPerspective))
10850 			{
10851 				if (!quals.empty())
10852 					quals += ", ";
10853 				quals += "center_no_perspective";
10854 			}
10855 		}
10856 
10857 		if (!quals.empty())
10858 			return " [[" + quals + "]]";
10859 	}
10860 
10861 	// Fragment function outputs
10862 	if (execution.model == ExecutionModelFragment && type.storage == StorageClassOutput)
10863 	{
10864 		if (is_builtin)
10865 		{
10866 			switch (builtin)
10867 			{
10868 			case BuiltInFragStencilRefEXT:
10869 				// Similar to PointSize, only mark FragStencilRef if there's a stencil buffer.
10870 				// Some shaders may include a FragStencilRef builtin even when used to render
10871 				// without a stencil attachment, and Metal will reject this builtin
10872 				// when compiling the shader into a render pipeline that does not set
10873 				// stencilAttachmentPixelFormat.
10874 				if (!msl_options.enable_frag_stencil_ref_builtin)
10875 					return "";
10876 				if (!msl_options.supports_msl_version(2, 1))
10877 					SPIRV_CROSS_THROW("Stencil export only supported in MSL 2.1 and up.");
10878 				return string(" [[") + builtin_qualifier(builtin) + "]]";
10879 
10880 			case BuiltInFragDepth:
10881 				// Ditto FragDepth.
10882 				if (!msl_options.enable_frag_depth_builtin)
10883 					return "";
10884 				/* fallthrough */
10885 			case BuiltInSampleMask:
10886 				return string(" [[") + builtin_qualifier(builtin) + "]]";
10887 
10888 			default:
10889 				return "";
10890 			}
10891 		}
10892 		uint32_t locn = get_member_location(type.self, index);
10893 		// Metal will likely complain about missing color attachments, too.
10894 		if (locn != k_unknown_location && !(msl_options.enable_frag_output_mask & (1 << locn)))
10895 			return "";
10896 		if (locn != k_unknown_location && has_member_decoration(type.self, index, DecorationIndex))
10897 			return join(" [[color(", locn, "), index(", get_member_decoration(type.self, index, DecorationIndex),
10898 			            ")]]");
10899 		else if (locn != k_unknown_location)
10900 			return join(" [[color(", locn, ")]]");
10901 		else if (has_member_decoration(type.self, index, DecorationIndex))
10902 			return join(" [[index(", get_member_decoration(type.self, index, DecorationIndex), ")]]");
10903 		else
10904 			return "";
10905 	}
10906 
10907 	// Compute function inputs
10908 	if (execution.model == ExecutionModelGLCompute && type.storage == StorageClassInput)
10909 	{
10910 		if (is_builtin)
10911 		{
10912 			switch (builtin)
10913 			{
10914 			case BuiltInNumSubgroups:
10915 			case BuiltInSubgroupId:
10916 			case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
10917 			case BuiltInSubgroupSize: // FIXME: Should work in any stage
10918 				if (msl_options.emulate_subgroups)
10919 					break;
10920 				/* fallthrough */
10921 			case BuiltInGlobalInvocationId:
10922 			case BuiltInWorkgroupId:
10923 			case BuiltInNumWorkgroups:
10924 			case BuiltInLocalInvocationId:
10925 			case BuiltInLocalInvocationIndex:
10926 				return string(" [[") + builtin_qualifier(builtin) + "]]";
10927 
10928 			default:
10929 				return "";
10930 			}
10931 		}
10932 	}
10933 
10934 	return "";
10935 }
10936 
10937 // Returns the location decoration of the member with the specified index in the specified type.
10938 // If the location of the member has been explicitly set, that location is used. If not, this
10939 // function assumes the members are ordered in their location order, and simply returns the
10940 // index as the location.
get_member_location(uint32_t type_id,uint32_t index,uint32_t * comp) const10941 uint32_t CompilerMSL::get_member_location(uint32_t type_id, uint32_t index, uint32_t *comp) const
10942 {
10943 	if (comp)
10944 	{
10945 		if (has_member_decoration(type_id, index, DecorationComponent))
10946 			*comp = get_member_decoration(type_id, index, DecorationComponent);
10947 		else
10948 			*comp = k_unknown_component;
10949 	}
10950 
10951 	if (has_member_decoration(type_id, index, DecorationLocation))
10952 		return get_member_decoration(type_id, index, DecorationLocation);
10953 	else
10954 		return k_unknown_location;
10955 }
10956 
get_or_allocate_builtin_input_member_location(spv::BuiltIn builtin,uint32_t type_id,uint32_t index,uint32_t * comp)10957 uint32_t CompilerMSL::get_or_allocate_builtin_input_member_location(spv::BuiltIn builtin,
10958                                                                     uint32_t type_id, uint32_t index,
10959                                                                     uint32_t *comp)
10960 {
10961 	uint32_t loc = get_member_location(type_id, index, comp);
10962 	if (loc != k_unknown_location)
10963 		return loc;
10964 
10965 	if (comp)
10966 		*comp = k_unknown_component;
10967 
10968 	// Late allocation. Find a location which is unused by the application.
10969 	// This can happen for built-in inputs in tessellation which are mixed and matched with user inputs.
10970 	auto &mbr_type = get<SPIRType>(get<SPIRType>(type_id).member_types[index]);
10971 	uint32_t count = type_to_location_count(mbr_type);
10972 
10973 	loc = 0;
10974 
10975 	const auto location_range_in_use = [this](uint32_t location, uint32_t location_count) -> bool {
10976 		for (uint32_t i = 0; i < location_count; i++)
10977 			if (location_inputs_in_use.count(location + i) != 0)
10978 				return true;
10979 		return false;
10980 	};
10981 
10982 	while (location_range_in_use(loc, count))
10983 		loc++;
10984 
10985 	set_member_decoration(type_id, index, DecorationLocation, loc);
10986 
10987 	// Triangle tess level inputs are shared in one packed float4,
10988 	// mark both builtins as sharing one location.
10989 	if (get_execution_mode_bitset().get(ExecutionModeTriangles) &&
10990 	    (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter))
10991 	{
10992 		builtin_to_automatic_input_location[BuiltInTessLevelInner] = loc;
10993 		builtin_to_automatic_input_location[BuiltInTessLevelOuter] = loc;
10994 	}
10995 	else
10996 		builtin_to_automatic_input_location[builtin] = loc;
10997 
10998 	mark_location_as_used_by_shader(loc, mbr_type, StorageClassInput, true);
10999 	return loc;
11000 }
11001 
11002 // Returns the type declaration for a function, including the
11003 // entry type if the current function is the entry point function
func_type_decl(SPIRType & type)11004 string CompilerMSL::func_type_decl(SPIRType &type)
11005 {
11006 	// The regular function return type. If not processing the entry point function, that's all we need
11007 	string return_type = type_to_glsl(type) + type_to_array_glsl(type);
11008 	if (!processing_entry_point)
11009 		return return_type;
11010 
11011 	// If an outgoing interface block has been defined, and it should be returned, override the entry point return type
11012 	bool ep_should_return_output = !get_is_rasterization_disabled();
11013 	if (stage_out_var_id && ep_should_return_output)
11014 		return_type = type_to_glsl(get_stage_out_struct_type()) + type_to_array_glsl(type);
11015 
11016 	// Prepend a entry type, based on the execution model
11017 	string entry_type;
11018 	auto &execution = get_entry_point();
11019 	switch (execution.model)
11020 	{
11021 	case ExecutionModelVertex:
11022 		if (msl_options.vertex_for_tessellation && !msl_options.supports_msl_version(1, 2))
11023 			SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
11024 		entry_type = msl_options.vertex_for_tessellation ? "kernel" : "vertex";
11025 		break;
11026 	case ExecutionModelTessellationEvaluation:
11027 		if (!msl_options.supports_msl_version(1, 2))
11028 			SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
11029 		if (execution.flags.get(ExecutionModeIsolines))
11030 			SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
11031 		if (msl_options.is_ios())
11032 			entry_type =
11033 			    join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ") ]] vertex");
11034 		else
11035 			entry_type = join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ", ",
11036 			                  execution.output_vertices, ") ]] vertex");
11037 		break;
11038 	case ExecutionModelFragment:
11039 		entry_type = execution.flags.get(ExecutionModeEarlyFragmentTests) ||
11040 		                     execution.flags.get(ExecutionModePostDepthCoverage) ?
11041 		                 "[[ early_fragment_tests ]] fragment" :
11042 		                 "fragment";
11043 		break;
11044 	case ExecutionModelTessellationControl:
11045 		if (!msl_options.supports_msl_version(1, 2))
11046 			SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
11047 		if (execution.flags.get(ExecutionModeIsolines))
11048 			SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
11049 		/* fallthrough */
11050 	case ExecutionModelGLCompute:
11051 	case ExecutionModelKernel:
11052 		entry_type = "kernel";
11053 		break;
11054 	default:
11055 		entry_type = "unknown";
11056 		break;
11057 	}
11058 
11059 	return entry_type + " " + return_type;
11060 }
11061 
11062 // In MSL, address space qualifiers are required for all pointer or reference variables
get_argument_address_space(const SPIRVariable & argument)11063 string CompilerMSL::get_argument_address_space(const SPIRVariable &argument)
11064 {
11065 	const auto &type = get<SPIRType>(argument.basetype);
11066 	return get_type_address_space(type, argument.self, true);
11067 }
11068 
get_type_address_space(const SPIRType & type,uint32_t id,bool argument)11069 string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bool argument)
11070 {
11071 	// This can be called for variable pointer contexts as well, so be very careful about which method we choose.
11072 	Bitset flags;
11073 	auto *var = maybe_get<SPIRVariable>(id);
11074 	if (var && type.basetype == SPIRType::Struct &&
11075 	    (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
11076 		flags = get_buffer_block_flags(id);
11077 	else
11078 		flags = get_decoration_bitset(id);
11079 
11080 	const char *addr_space = nullptr;
11081 	switch (type.storage)
11082 	{
11083 	case StorageClassWorkgroup:
11084 		addr_space = "threadgroup";
11085 		break;
11086 
11087 	case StorageClassStorageBuffer:
11088 	{
11089 		// For arguments from variable pointers, we use the write count deduction, so
11090 		// we should not assume any constness here. Only for global SSBOs.
11091 		bool readonly = false;
11092 		if (!var || has_decoration(type.self, DecorationBlock))
11093 			readonly = flags.get(DecorationNonWritable);
11094 
11095 		addr_space = readonly ? "const device" : "device";
11096 		break;
11097 	}
11098 
11099 	case StorageClassUniform:
11100 	case StorageClassUniformConstant:
11101 	case StorageClassPushConstant:
11102 		if (type.basetype == SPIRType::Struct)
11103 		{
11104 			bool ssbo = has_decoration(type.self, DecorationBufferBlock);
11105 			if (ssbo)
11106 				addr_space = flags.get(DecorationNonWritable) ? "const device" : "device";
11107 			else
11108 				addr_space = "constant";
11109 		}
11110 		else if (!argument)
11111 		{
11112 			addr_space = "constant";
11113 		}
11114 		else if (type_is_msl_framebuffer_fetch(type))
11115 		{
11116 			// Subpass inputs are passed around by value.
11117 			addr_space = "";
11118 		}
11119 		break;
11120 
11121 	case StorageClassFunction:
11122 	case StorageClassGeneric:
11123 		break;
11124 
11125 	case StorageClassInput:
11126 		if (get_execution_model() == ExecutionModelTessellationControl && var &&
11127 		    var->basevariable == stage_in_ptr_var_id)
11128 			addr_space = msl_options.multi_patch_workgroup ? "constant" : "threadgroup";
11129 		if (get_execution_model() == ExecutionModelFragment && var && var->basevariable == stage_in_var_id)
11130 			addr_space = "thread";
11131 		break;
11132 
11133 	case StorageClassOutput:
11134 		if (capture_output_to_buffer)
11135 		{
11136 			if (var && type.storage == StorageClassOutput)
11137 			{
11138 				bool is_masked = is_stage_output_variable_masked(*var);
11139 
11140 				if (is_masked)
11141 				{
11142 					if (is_tessellation_shader())
11143 						addr_space = "threadgroup";
11144 					else
11145 						addr_space = "thread";
11146 				}
11147 				else if (variable_decl_is_remapped_storage(*var, StorageClassWorkgroup))
11148 					addr_space = "threadgroup";
11149 			}
11150 
11151 			if (!addr_space)
11152 				addr_space = "device";
11153 		}
11154 		break;
11155 
11156 	default:
11157 		break;
11158 	}
11159 
11160 	if (!addr_space)
11161 	{
11162 		// No address space for plain values.
11163 		addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : "";
11164 	}
11165 
11166 	return join(flags.get(DecorationVolatile) || flags.get(DecorationCoherent) ? "volatile " : "", addr_space);
11167 }
11168 
to_restrict(uint32_t id,bool space)11169 const char *CompilerMSL::to_restrict(uint32_t id, bool space)
11170 {
11171 	// This can be called for variable pointer contexts as well, so be very careful about which method we choose.
11172 	Bitset flags;
11173 	if (ir.ids[id].get_type() == TypeVariable)
11174 	{
11175 		uint32_t type_id = expression_type_id(id);
11176 		auto &type = expression_type(id);
11177 		if (type.basetype == SPIRType::Struct &&
11178 		    (has_decoration(type_id, DecorationBlock) || has_decoration(type_id, DecorationBufferBlock)))
11179 			flags = get_buffer_block_flags(id);
11180 		else
11181 			flags = get_decoration_bitset(id);
11182 	}
11183 	else
11184 		flags = get_decoration_bitset(id);
11185 
11186 	return flags.get(DecorationRestrict) ? (space ? "restrict " : "restrict") : "";
11187 }
11188 
entry_point_arg_stage_in()11189 string CompilerMSL::entry_point_arg_stage_in()
11190 {
11191 	string decl;
11192 
11193 	if (get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup)
11194 		return decl;
11195 
11196 	// Stage-in structure
11197 	uint32_t stage_in_id;
11198 	if (get_execution_model() == ExecutionModelTessellationEvaluation)
11199 		stage_in_id = patch_stage_in_var_id;
11200 	else
11201 		stage_in_id = stage_in_var_id;
11202 
11203 	if (stage_in_id)
11204 	{
11205 		auto &var = get<SPIRVariable>(stage_in_id);
11206 		auto &type = get_variable_data_type(var);
11207 
11208 		add_resource_name(var.self);
11209 		decl = join(type_to_glsl(type), " ", to_name(var.self), " [[stage_in]]");
11210 	}
11211 
11212 	return decl;
11213 }
11214 
11215 // Returns true if this input builtin should be a direct parameter on a shader function parameter list,
11216 // and false for builtins that should be passed or calculated some other way.
is_direct_input_builtin(BuiltIn bi_type)11217 bool CompilerMSL::is_direct_input_builtin(BuiltIn bi_type)
11218 {
11219 	switch (bi_type)
11220 	{
11221 	// Vertex function in
11222 	case BuiltInVertexId:
11223 	case BuiltInVertexIndex:
11224 	case BuiltInBaseVertex:
11225 	case BuiltInInstanceId:
11226 	case BuiltInInstanceIndex:
11227 	case BuiltInBaseInstance:
11228 		return get_execution_model() != ExecutionModelVertex || !msl_options.vertex_for_tessellation;
11229 	// Tess. control function in
11230 	case BuiltInPosition:
11231 	case BuiltInPointSize:
11232 	case BuiltInClipDistance:
11233 	case BuiltInCullDistance:
11234 	case BuiltInPatchVertices:
11235 		return false;
11236 	case BuiltInInvocationId:
11237 	case BuiltInPrimitiveId:
11238 		return get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup;
11239 	// Tess. evaluation function in
11240 	case BuiltInTessLevelInner:
11241 	case BuiltInTessLevelOuter:
11242 		return false;
11243 	// Fragment function in
11244 	case BuiltInSamplePosition:
11245 	case BuiltInHelperInvocation:
11246 	case BuiltInBaryCoordNV:
11247 	case BuiltInBaryCoordNoPerspNV:
11248 		return false;
11249 	case BuiltInViewIndex:
11250 		return get_execution_model() == ExecutionModelFragment && msl_options.multiview &&
11251 		       msl_options.multiview_layered_rendering;
11252 	// Compute function in
11253 	case BuiltInSubgroupId:
11254 	case BuiltInNumSubgroups:
11255 		return !msl_options.emulate_subgroups;
11256 	// Any stage function in
11257 	case BuiltInDeviceIndex:
11258 	case BuiltInSubgroupEqMask:
11259 	case BuiltInSubgroupGeMask:
11260 	case BuiltInSubgroupGtMask:
11261 	case BuiltInSubgroupLeMask:
11262 	case BuiltInSubgroupLtMask:
11263 		return false;
11264 	case BuiltInSubgroupSize:
11265 		if (msl_options.fixed_subgroup_size != 0)
11266 			return false;
11267 		/* fallthrough */
11268 	case BuiltInSubgroupLocalInvocationId:
11269 		return !msl_options.emulate_subgroups;
11270 	default:
11271 		return true;
11272 	}
11273 }
11274 
11275 // Returns true if this is a fragment shader that runs per sample, and false otherwise.
is_sample_rate() const11276 bool CompilerMSL::is_sample_rate() const
11277 {
11278 	auto &caps = get_declared_capabilities();
11279 	return get_execution_model() == ExecutionModelFragment &&
11280 	       (msl_options.force_sample_rate_shading ||
11281 	        std::find(caps.begin(), caps.end(), CapabilitySampleRateShading) != caps.end() ||
11282 	        (msl_options.use_framebuffer_fetch_subpasses && need_subpass_input));
11283 }
11284 
entry_point_args_builtin(string & ep_args)11285 void CompilerMSL::entry_point_args_builtin(string &ep_args)
11286 {
11287 	// Builtin variables
11288 	SmallVector<pair<SPIRVariable *, BuiltIn>, 8> active_builtins;
11289 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
11290 		if (var.storage != StorageClassInput)
11291 			return;
11292 
11293 		auto bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
11294 
11295 		// Don't emit SamplePosition as a separate parameter. In the entry
11296 		// point, we get that by calling get_sample_position() on the sample ID.
11297 		if (is_builtin_variable(var) &&
11298 		    get_variable_data_type(var).basetype != SPIRType::Struct &&
11299 		    get_variable_data_type(var).basetype != SPIRType::ControlPointArray)
11300 		{
11301 			// If the builtin is not part of the active input builtin set, don't emit it.
11302 			// Relevant for multiple entry-point modules which might declare unused builtins.
11303 			if (!active_input_builtins.get(bi_type) || !interface_variable_exists_in_entry_point(var_id))
11304 				return;
11305 
11306 			// Remember this variable. We may need to correct its type.
11307 			active_builtins.push_back(make_pair(&var, bi_type));
11308 
11309 			if (is_direct_input_builtin(bi_type))
11310 			{
11311 				if (!ep_args.empty())
11312 					ep_args += ", ";
11313 
11314 				// Handle HLSL-style 0-based vertex/instance index.
11315 				builtin_declaration = true;
11316 				ep_args += builtin_type_decl(bi_type, var_id) + " " + to_expression(var_id);
11317 				ep_args += " [[" + builtin_qualifier(bi_type);
11318 				if (bi_type == BuiltInSampleMask && get_entry_point().flags.get(ExecutionModePostDepthCoverage))
11319 				{
11320 					if (!msl_options.supports_msl_version(2))
11321 						SPIRV_CROSS_THROW("Post-depth coverage requires MSL 2.0.");
11322 					if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 3))
11323 						SPIRV_CROSS_THROW("Post-depth coverage on Mac requires MSL 2.3.");
11324 					ep_args += ", post_depth_coverage";
11325 				}
11326 				ep_args += "]]";
11327 				builtin_declaration = false;
11328 			}
11329 		}
11330 
11331 		if (has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase))
11332 		{
11333 			// This is a special implicit builtin, not corresponding to any SPIR-V builtin,
11334 			// which holds the base that was passed to vkCmdDispatchBase() or vkCmdDrawIndexed(). If it's present,
11335 			// assume we emitted it for a good reason.
11336 			assert(msl_options.supports_msl_version(1, 2));
11337 			if (!ep_args.empty())
11338 				ep_args += ", ";
11339 
11340 			ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_origin]]";
11341 		}
11342 
11343 		if (has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize))
11344 		{
11345 			// This is another special implicit builtin, not corresponding to any SPIR-V builtin,
11346 			// which holds the number of vertices and instances to draw. If it's present,
11347 			// assume we emitted it for a good reason.
11348 			assert(msl_options.supports_msl_version(1, 2));
11349 			if (!ep_args.empty())
11350 				ep_args += ", ";
11351 
11352 			ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_size]]";
11353 		}
11354 	});
11355 
11356 	// Correct the types of all encountered active builtins. We couldn't do this before
11357 	// because ensure_correct_builtin_type() may increase the bound, which isn't allowed
11358 	// while iterating over IDs.
11359 	for (auto &var : active_builtins)
11360 		var.first->basetype = ensure_correct_builtin_type(var.first->basetype, var.second);
11361 
11362 	// Handle HLSL-style 0-based vertex/instance index.
11363 	if (needs_base_vertex_arg == TriState::Yes)
11364 		ep_args += built_in_func_arg(BuiltInBaseVertex, !ep_args.empty());
11365 
11366 	if (needs_base_instance_arg == TriState::Yes)
11367 		ep_args += built_in_func_arg(BuiltInBaseInstance, !ep_args.empty());
11368 
11369 	if (capture_output_to_buffer)
11370 	{
11371 		// Add parameters to hold the indirect draw parameters and the shader output. This has to be handled
11372 		// specially because it needs to be a pointer, not a reference.
11373 		if (stage_out_var_id)
11374 		{
11375 			if (!ep_args.empty())
11376 				ep_args += ", ";
11377 			ep_args += join("device ", type_to_glsl(get_stage_out_struct_type()), "* ", output_buffer_var_name,
11378 			                " [[buffer(", msl_options.shader_output_buffer_index, ")]]");
11379 		}
11380 
11381 		if (get_execution_model() == ExecutionModelTessellationControl)
11382 		{
11383 			if (!ep_args.empty())
11384 				ep_args += ", ";
11385 			ep_args +=
11386 			    join("constant uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
11387 		}
11388 		else if (stage_out_var_id &&
11389 		         !(get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
11390 		{
11391 			if (!ep_args.empty())
11392 				ep_args += ", ";
11393 			ep_args +=
11394 			    join("device uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
11395 		}
11396 
11397 		if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation &&
11398 		    (active_input_builtins.get(BuiltInVertexIndex) || active_input_builtins.get(BuiltInVertexId)) &&
11399 		    msl_options.vertex_index_type != Options::IndexType::None)
11400 		{
11401 			// Add the index buffer so we can set gl_VertexIndex correctly.
11402 			if (!ep_args.empty())
11403 				ep_args += ", ";
11404 			switch (msl_options.vertex_index_type)
11405 			{
11406 			case Options::IndexType::None:
11407 				break;
11408 			case Options::IndexType::UInt16:
11409 				ep_args += join("const device ushort* ", index_buffer_var_name, " [[buffer(",
11410 				                msl_options.shader_index_buffer_index, ")]]");
11411 				break;
11412 			case Options::IndexType::UInt32:
11413 				ep_args += join("const device uint* ", index_buffer_var_name, " [[buffer(",
11414 				                msl_options.shader_index_buffer_index, ")]]");
11415 				break;
11416 			}
11417 		}
11418 
11419 		// Tessellation control shaders get three additional parameters:
11420 		// a buffer to hold the per-patch data, a buffer to hold the per-patch
11421 		// tessellation levels, and a block of workgroup memory to hold the
11422 		// input control point data.
11423 		if (get_execution_model() == ExecutionModelTessellationControl)
11424 		{
11425 			if (patch_stage_out_var_id)
11426 			{
11427 				if (!ep_args.empty())
11428 					ep_args += ", ";
11429 				ep_args +=
11430 				    join("device ", type_to_glsl(get_patch_stage_out_struct_type()), "* ", patch_output_buffer_var_name,
11431 				         " [[buffer(", convert_to_string(msl_options.shader_patch_output_buffer_index), ")]]");
11432 			}
11433 			if (!ep_args.empty())
11434 				ep_args += ", ";
11435 			ep_args += join("device ", get_tess_factor_struct_name(), "* ", tess_factor_buffer_var_name, " [[buffer(",
11436 			                convert_to_string(msl_options.shader_tess_factor_buffer_index), ")]]");
11437 
11438 			// Initializer for tess factors must be handled specially since it's never declared as a normal variable.
11439 			uint32_t outer_factor_initializer_id = 0;
11440 			uint32_t inner_factor_initializer_id = 0;
11441 			ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
11442 				if (!has_decoration(var.self, DecorationBuiltIn) || var.storage != StorageClassOutput || !var.initializer)
11443 					return;
11444 
11445 				BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
11446 				if (builtin == BuiltInTessLevelInner)
11447 					inner_factor_initializer_id = var.initializer;
11448 				else if (builtin == BuiltInTessLevelOuter)
11449 					outer_factor_initializer_id = var.initializer;
11450 			});
11451 
11452 			const SPIRConstant *c = nullptr;
11453 
11454 			if (outer_factor_initializer_id && (c = maybe_get<SPIRConstant>(outer_factor_initializer_id)))
11455 			{
11456 				auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
11457 				entry_func.fixup_hooks_in.push_back([=]() {
11458 					uint32_t components = get_execution_mode_bitset().get(ExecutionModeTriangles) ? 3 : 4;
11459 					for (uint32_t i = 0; i < components; i++)
11460 					{
11461 						statement(builtin_to_glsl(BuiltInTessLevelOuter, StorageClassOutput), "[", i, "] = ",
11462 						          "half(", to_expression(c->subconstants[i]), ");");
11463 					}
11464 				});
11465 			}
11466 
11467 			if (inner_factor_initializer_id && (c = maybe_get<SPIRConstant>(inner_factor_initializer_id)))
11468 			{
11469 				auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
11470 				if (get_execution_mode_bitset().get(ExecutionModeTriangles))
11471 				{
11472 					entry_func.fixup_hooks_in.push_back([=]() {
11473 						statement(builtin_to_glsl(BuiltInTessLevelInner, StorageClassOutput), " = ", "half(",
11474 						          to_expression(c->subconstants[0]), ");");
11475 					});
11476 				}
11477 				else
11478 				{
11479 					entry_func.fixup_hooks_in.push_back([=]() {
11480 						for (uint32_t i = 0; i < 2; i++)
11481 						{
11482 							statement(builtin_to_glsl(BuiltInTessLevelInner, StorageClassOutput), "[", i, "] = ",
11483 							          "half(", to_expression(c->subconstants[i]), ");");
11484 						}
11485 					});
11486 				}
11487 			}
11488 
11489 			if (stage_in_var_id)
11490 			{
11491 				if (!ep_args.empty())
11492 					ep_args += ", ";
11493 				if (msl_options.multi_patch_workgroup)
11494 				{
11495 					ep_args += join("device ", type_to_glsl(get_stage_in_struct_type()), "* ", input_buffer_var_name,
11496 					                " [[buffer(", convert_to_string(msl_options.shader_input_buffer_index), ")]]");
11497 				}
11498 				else
11499 				{
11500 					ep_args += join("threadgroup ", type_to_glsl(get_stage_in_struct_type()), "* ", input_wg_var_name,
11501 					                " [[threadgroup(", convert_to_string(msl_options.shader_input_wg_index), ")]]");
11502 				}
11503 			}
11504 		}
11505 	}
11506 }
11507 
entry_point_args_argument_buffer(bool append_comma)11508 string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
11509 {
11510 	string ep_args = entry_point_arg_stage_in();
11511 	Bitset claimed_bindings;
11512 
11513 	for (uint32_t i = 0; i < kMaxArgumentBuffers; i++)
11514 	{
11515 		uint32_t id = argument_buffer_ids[i];
11516 		if (id == 0)
11517 			continue;
11518 
11519 		add_resource_name(id);
11520 		auto &var = get<SPIRVariable>(id);
11521 		auto &type = get_variable_data_type(var);
11522 
11523 		if (!ep_args.empty())
11524 			ep_args += ", ";
11525 
11526 		// Check if the argument buffer binding itself has been remapped.
11527 		uint32_t buffer_binding;
11528 		auto itr = resource_bindings.find({ get_entry_point().model, i, kArgumentBufferBinding });
11529 		if (itr != end(resource_bindings))
11530 		{
11531 			buffer_binding = itr->second.first.msl_buffer;
11532 			itr->second.second = true;
11533 		}
11534 		else
11535 		{
11536 			// As a fallback, directly map desc set <-> binding.
11537 			// If that was taken, take the next buffer binding.
11538 			if (claimed_bindings.get(i))
11539 				buffer_binding = next_metal_resource_index_buffer;
11540 			else
11541 				buffer_binding = i;
11542 		}
11543 
11544 		claimed_bindings.set(buffer_binding);
11545 
11546 		ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id) + to_name(id);
11547 		ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]";
11548 
11549 		next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1);
11550 	}
11551 
11552 	entry_point_args_discrete_descriptors(ep_args);
11553 	entry_point_args_builtin(ep_args);
11554 
11555 	if (!ep_args.empty() && append_comma)
11556 		ep_args += ", ";
11557 
11558 	return ep_args;
11559 }
11560 
find_constexpr_sampler(uint32_t id) const11561 const MSLConstexprSampler *CompilerMSL::find_constexpr_sampler(uint32_t id) const
11562 {
11563 	// Try by ID.
11564 	{
11565 		auto itr = constexpr_samplers_by_id.find(id);
11566 		if (itr != end(constexpr_samplers_by_id))
11567 			return &itr->second;
11568 	}
11569 
11570 	// Try by binding.
11571 	{
11572 		uint32_t desc_set = get_decoration(id, DecorationDescriptorSet);
11573 		uint32_t binding = get_decoration(id, DecorationBinding);
11574 
11575 		auto itr = constexpr_samplers_by_binding.find({ desc_set, binding });
11576 		if (itr != end(constexpr_samplers_by_binding))
11577 			return &itr->second;
11578 	}
11579 
11580 	return nullptr;
11581 }
11582 
entry_point_args_discrete_descriptors(string & ep_args)11583 void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
11584 {
11585 	// Output resources, sorted by resource index & type
11586 	// We need to sort to work around a bug on macOS 10.13 with NVidia drivers where switching between shaders
11587 	// with different order of buffers can result in issues with buffer assignments inside the driver.
11588 	struct Resource
11589 	{
11590 		SPIRVariable *var;
11591 		string name;
11592 		SPIRType::BaseType basetype;
11593 		uint32_t index;
11594 		uint32_t plane;
11595 		uint32_t secondary_index;
11596 	};
11597 
11598 	SmallVector<Resource> resources;
11599 
11600 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
11601 		if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
11602 		     var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer) &&
11603 		    !is_hidden_variable(var))
11604 		{
11605 			auto &type = get_variable_data_type(var);
11606 
11607 			if (is_supported_argument_buffer_type(type) && var.storage != StorageClassPushConstant)
11608 			{
11609 				uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
11610 				if (descriptor_set_is_argument_buffer(desc_set))
11611 					return;
11612 			}
11613 
11614 			const MSLConstexprSampler *constexpr_sampler = nullptr;
11615 			if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
11616 			{
11617 				constexpr_sampler = find_constexpr_sampler(var_id);
11618 				if (constexpr_sampler)
11619 				{
11620 					// Mark this ID as a constexpr sampler for later in case it came from set/bindings.
11621 					constexpr_samplers_by_id[var_id] = *constexpr_sampler;
11622 				}
11623 			}
11624 
11625 			// Emulate texture2D atomic operations
11626 			uint32_t secondary_index = 0;
11627 			if (atomic_image_vars.count(var.self))
11628 			{
11629 				secondary_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0);
11630 			}
11631 
11632 			if (type.basetype == SPIRType::SampledImage)
11633 			{
11634 				add_resource_name(var_id);
11635 
11636 				uint32_t plane_count = 1;
11637 				if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
11638 					plane_count = constexpr_sampler->planes;
11639 
11640 				for (uint32_t i = 0; i < plane_count; i++)
11641 					resources.push_back({ &var, to_name(var_id), SPIRType::Image,
11642 					                      get_metal_resource_index(var, SPIRType::Image, i), i, secondary_index });
11643 
11644 				if (type.image.dim != DimBuffer && !constexpr_sampler)
11645 				{
11646 					resources.push_back({ &var, to_sampler_expression(var_id), SPIRType::Sampler,
11647 					                      get_metal_resource_index(var, SPIRType::Sampler), 0, 0 });
11648 				}
11649 			}
11650 			else if (!constexpr_sampler)
11651 			{
11652 				// constexpr samplers are not declared as resources.
11653 				add_resource_name(var_id);
11654 				resources.push_back({ &var, to_name(var_id), type.basetype,
11655 				                      get_metal_resource_index(var, type.basetype), 0, secondary_index });
11656 			}
11657 		}
11658 	});
11659 
11660 	sort(resources.begin(), resources.end(), [](const Resource &lhs, const Resource &rhs) {
11661 		return tie(lhs.basetype, lhs.index) < tie(rhs.basetype, rhs.index);
11662 	});
11663 
11664 	for (auto &r : resources)
11665 	{
11666 		auto &var = *r.var;
11667 		auto &type = get_variable_data_type(var);
11668 
11669 		uint32_t var_id = var.self;
11670 
11671 		switch (r.basetype)
11672 		{
11673 		case SPIRType::Struct:
11674 		{
11675 			auto &m = ir.meta[type.self];
11676 			if (m.members.size() == 0)
11677 				break;
11678 			if (!type.array.empty())
11679 			{
11680 				if (type.array.size() > 1)
11681 					SPIRV_CROSS_THROW("Arrays of arrays of buffers are not supported.");
11682 
11683 				// Metal doesn't directly support this, so we must expand the
11684 				// array. We'll declare a local array to hold these elements
11685 				// later.
11686 				uint32_t array_size = to_array_size_literal(type);
11687 
11688 				if (array_size == 0)
11689 					SPIRV_CROSS_THROW("Unsized arrays of buffers are not supported in MSL.");
11690 
11691 				// Allow Metal to use the array<T> template to make arrays a value type
11692 				is_using_builtin_array = true;
11693 				buffer_arrays.push_back(var_id);
11694 				for (uint32_t i = 0; i < array_size; ++i)
11695 				{
11696 					if (!ep_args.empty())
11697 						ep_args += ", ";
11698 					ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "* " + to_restrict(var_id) +
11699 					           r.name + "_" + convert_to_string(i);
11700 					ep_args += " [[buffer(" + convert_to_string(r.index + i) + ")";
11701 					if (interlocked_resources.count(var_id))
11702 						ep_args += ", raster_order_group(0)";
11703 					ep_args += "]]";
11704 				}
11705 				is_using_builtin_array = false;
11706 			}
11707 			else
11708 			{
11709 				if (!ep_args.empty())
11710 					ep_args += ", ";
11711 				ep_args +=
11712 				    get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(var_id) + r.name;
11713 				ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
11714 				if (interlocked_resources.count(var_id))
11715 					ep_args += ", raster_order_group(0)";
11716 				ep_args += "]]";
11717 			}
11718 			break;
11719 		}
11720 		case SPIRType::Sampler:
11721 			if (!ep_args.empty())
11722 				ep_args += ", ";
11723 			ep_args += sampler_type(type, var_id) + " " + r.name;
11724 			ep_args += " [[sampler(" + convert_to_string(r.index) + ")]]";
11725 			break;
11726 		case SPIRType::Image:
11727 		{
11728 			if (!ep_args.empty())
11729 				ep_args += ", ";
11730 
11731 			// Use Metal's native frame-buffer fetch API for subpass inputs.
11732 			const auto &basetype = get<SPIRType>(var.basetype);
11733 			if (!type_is_msl_framebuffer_fetch(basetype))
11734 			{
11735 				ep_args += image_type_glsl(type, var_id) + " " + r.name;
11736 				if (r.plane > 0)
11737 					ep_args += join(plane_name_suffix, r.plane);
11738 				ep_args += " [[texture(" + convert_to_string(r.index) + ")";
11739 				if (interlocked_resources.count(var_id))
11740 					ep_args += ", raster_order_group(0)";
11741 				ep_args += "]]";
11742 			}
11743 			else
11744 			{
11745 				if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 3))
11746 					SPIRV_CROSS_THROW("Framebuffer fetch on Mac is not supported before MSL 2.3.");
11747 				ep_args += image_type_glsl(type, var_id) + " " + r.name;
11748 				ep_args += " [[color(" + convert_to_string(r.index) + ")]]";
11749 			}
11750 
11751 			// Emulate texture2D atomic operations
11752 			if (atomic_image_vars.count(var.self))
11753 			{
11754 				ep_args += ", device atomic_" + type_to_glsl(get<SPIRType>(basetype.image.type), 0);
11755 				ep_args += "* " + r.name + "_atomic";
11756 				ep_args += " [[buffer(" + convert_to_string(r.secondary_index) + ")";
11757 				if (interlocked_resources.count(var_id))
11758 					ep_args += ", raster_order_group(0)";
11759 				ep_args += "]]";
11760 			}
11761 			break;
11762 		}
11763 		default:
11764 			if (!ep_args.empty())
11765 				ep_args += ", ";
11766 			if (!type.pointer)
11767 				ep_args += get_type_address_space(get<SPIRType>(var.basetype), var_id) + " " +
11768 				           type_to_glsl(type, var_id) + "& " + r.name;
11769 			else
11770 				ep_args += type_to_glsl(type, var_id) + " " + r.name;
11771 			ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
11772 			if (interlocked_resources.count(var_id))
11773 				ep_args += ", raster_order_group(0)";
11774 			ep_args += "]]";
11775 			break;
11776 		}
11777 	}
11778 }
11779 
11780 // Returns a string containing a comma-delimited list of args for the entry point function
11781 // This is the "classic" method of MSL 1 when we don't have argument buffer support.
entry_point_args_classic(bool append_comma)11782 string CompilerMSL::entry_point_args_classic(bool append_comma)
11783 {
11784 	string ep_args = entry_point_arg_stage_in();
11785 	entry_point_args_discrete_descriptors(ep_args);
11786 	entry_point_args_builtin(ep_args);
11787 
11788 	if (!ep_args.empty() && append_comma)
11789 		ep_args += ", ";
11790 
11791 	return ep_args;
11792 }
11793 
fix_up_shader_inputs_outputs()11794 void CompilerMSL::fix_up_shader_inputs_outputs()
11795 {
11796 	auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
11797 
11798 	// Emit a guard to ensure we don't execute beyond the last vertex.
11799 	// Vertex shaders shouldn't have the problems with barriers in non-uniform control flow that
11800 	// tessellation control shaders do, so early returns should be OK. We may need to revisit this
11801 	// if it ever becomes possible to use barriers from a vertex shader.
11802 	if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
11803 	{
11804 		entry_func.fixup_hooks_in.push_back([this]() {
11805 			statement("if (any(", to_expression(builtin_invocation_id_id),
11806 			          " >= ", to_expression(builtin_stage_input_size_id), "))");
11807 			statement("    return;");
11808 		});
11809 	}
11810 
11811 	// Look for sampled images and buffer. Add hooks to set up the swizzle constants or array lengths.
11812 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
11813 		auto &type = get_variable_data_type(var);
11814 		uint32_t var_id = var.self;
11815 		bool ssbo = has_decoration(type.self, DecorationBufferBlock);
11816 
11817 		if (var.storage == StorageClassUniformConstant && !is_hidden_variable(var))
11818 		{
11819 			if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
11820 			{
11821 				entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
11822 					bool is_array_type = !type.array.empty();
11823 
11824 					uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
11825 					if (descriptor_set_is_argument_buffer(desc_set))
11826 					{
11827 						statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
11828 						          is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
11829 						          ".spvSwizzleConstants", "[",
11830 						          convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
11831 					}
11832 					else
11833 					{
11834 						// If we have an array of images, we need to be able to index into it, so take a pointer instead.
11835 						statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
11836 						          is_array_type ? " = &" : " = ", to_name(swizzle_buffer_id), "[",
11837 						          convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
11838 					}
11839 				});
11840 			}
11841 		}
11842 		else if ((var.storage == StorageClassStorageBuffer || (var.storage == StorageClassUniform && ssbo)) &&
11843 		         !is_hidden_variable(var))
11844 		{
11845 			if (buffers_requiring_array_length.count(var.self))
11846 			{
11847 				entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
11848 					bool is_array_type = !type.array.empty();
11849 
11850 					uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
11851 					if (descriptor_set_is_argument_buffer(desc_set))
11852 					{
11853 						statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
11854 						          is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
11855 						          ".spvBufferSizeConstants", "[",
11856 						          convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
11857 					}
11858 					else
11859 					{
11860 						// If we have an array of images, we need to be able to index into it, so take a pointer instead.
11861 						statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
11862 						          is_array_type ? " = &" : " = ", to_name(buffer_size_buffer_id), "[",
11863 						          convert_to_string(get_metal_resource_index(var, type.basetype)), "];");
11864 					}
11865 				});
11866 			}
11867 		}
11868 	});
11869 
11870 	// Builtin variables
11871 	ir.for_each_typed_id<SPIRVariable>([this, &entry_func](uint32_t, SPIRVariable &var) {
11872 		uint32_t var_id = var.self;
11873 		BuiltIn bi_type = ir.meta[var_id].decoration.builtin_type;
11874 
11875 		if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
11876 			return;
11877 		if (!interface_variable_exists_in_entry_point(var.self))
11878 			return;
11879 
11880 		if (var.storage == StorageClassInput && is_builtin_variable(var) && active_input_builtins.get(bi_type))
11881 		{
11882 			switch (bi_type)
11883 			{
11884 			case BuiltInSamplePosition:
11885 				entry_func.fixup_hooks_in.push_back([=]() {
11886 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = get_sample_position(",
11887 					          to_expression(builtin_sample_id_id), ");");
11888 				});
11889 				break;
11890 			case BuiltInFragCoord:
11891 				if (is_sample_rate())
11892 				{
11893 					entry_func.fixup_hooks_in.push_back([=]() {
11894 						statement(to_expression(var_id), ".xy += get_sample_position(",
11895 						          to_expression(builtin_sample_id_id), ") - 0.5;");
11896 					});
11897 				}
11898 				break;
11899 			case BuiltInHelperInvocation:
11900 				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
11901 					SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.3 on iOS.");
11902 				else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
11903 					SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
11904 
11905 				entry_func.fixup_hooks_in.push_back([=]() {
11906 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = simd_is_helper_thread();");
11907 				});
11908 				break;
11909 			case BuiltInInvocationId:
11910 				// This is direct-mapped without multi-patch workgroups.
11911 				if (get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup)
11912 					break;
11913 
11914 				entry_func.fixup_hooks_in.push_back([=]() {
11915 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11916 					          to_expression(builtin_invocation_id_id), ".x % ", this->get_entry_point().output_vertices,
11917 					          ";");
11918 				});
11919 				break;
11920 			case BuiltInPrimitiveId:
11921 				// This is natively supported by fragment and tessellation evaluation shaders.
11922 				// In tessellation control shaders, this is direct-mapped without multi-patch workgroups.
11923 				if (get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup)
11924 					break;
11925 
11926 				entry_func.fixup_hooks_in.push_back([=]() {
11927 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = min(",
11928 					          to_expression(builtin_invocation_id_id), ".x / ", this->get_entry_point().output_vertices,
11929 					          ", spvIndirectParams[1] - 1);");
11930 				});
11931 				break;
11932 			case BuiltInPatchVertices:
11933 				if (get_execution_model() == ExecutionModelTessellationEvaluation)
11934 					entry_func.fixup_hooks_in.push_back([=]() {
11935 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11936 						          to_expression(patch_stage_in_var_id), ".gl_in.size();");
11937 					});
11938 				else
11939 					entry_func.fixup_hooks_in.push_back([=]() {
11940 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = spvIndirectParams[0];");
11941 					});
11942 				break;
11943 			case BuiltInTessCoord:
11944 				// Emit a fixup to account for the shifted domain. Don't do this for triangles;
11945 				// MoltenVK will just reverse the winding order instead.
11946 				if (msl_options.tess_domain_origin_lower_left && !get_entry_point().flags.get(ExecutionModeTriangles))
11947 				{
11948 					string tc = to_expression(var_id);
11949 					entry_func.fixup_hooks_in.push_back([=]() { statement(tc, ".y = 1.0 - ", tc, ".y;"); });
11950 				}
11951 				break;
11952 			case BuiltInSubgroupId:
11953 				if (!msl_options.emulate_subgroups)
11954 					break;
11955 				// For subgroup emulation, this is the same as the local invocation index.
11956 				entry_func.fixup_hooks_in.push_back([=]() {
11957 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11958 					          to_expression(builtin_local_invocation_index_id), ";");
11959 				});
11960 				break;
11961 			case BuiltInNumSubgroups:
11962 				if (!msl_options.emulate_subgroups)
11963 					break;
11964 				// For subgroup emulation, this is the same as the workgroup size.
11965 				entry_func.fixup_hooks_in.push_back([=]() {
11966 					auto &type = expression_type(builtin_workgroup_size_id);
11967 					string size_expr = to_expression(builtin_workgroup_size_id);
11968 					if (type.vecsize >= 3)
11969 						size_expr = join(size_expr, ".x * ", size_expr, ".y * ", size_expr, ".z");
11970 					else if (type.vecsize == 2)
11971 						size_expr = join(size_expr, ".x * ", size_expr, ".y");
11972 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", size_expr, ";");
11973 				});
11974 				break;
11975 			case BuiltInSubgroupLocalInvocationId:
11976 				if (!msl_options.emulate_subgroups)
11977 					break;
11978 				// For subgroup emulation, assume subgroups of size 1.
11979 				entry_func.fixup_hooks_in.push_back(
11980 				    [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;"); });
11981 				break;
11982 			case BuiltInSubgroupSize:
11983 				if (msl_options.emulate_subgroups)
11984 				{
11985 					// For subgroup emulation, assume subgroups of size 1.
11986 					entry_func.fixup_hooks_in.push_back(
11987 					    [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = 1;"); });
11988 				}
11989 				else if (msl_options.fixed_subgroup_size != 0)
11990 				{
11991 					entry_func.fixup_hooks_in.push_back([=]() {
11992 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11993 						          msl_options.fixed_subgroup_size, ";");
11994 					});
11995 				}
11996 				break;
11997 			case BuiltInSubgroupEqMask:
11998 				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
11999 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
12000 				if (!msl_options.supports_msl_version(2, 1))
12001 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
12002 				entry_func.fixup_hooks_in.push_back([=]() {
12003 					if (msl_options.is_ios())
12004 					{
12005 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", "uint4(1 << ",
12006 						          to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
12007 					}
12008 					else
12009 					{
12010 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12011 						          to_expression(builtin_subgroup_invocation_id_id), " >= 32 ? uint4(0, (1 << (",
12012 						          to_expression(builtin_subgroup_invocation_id_id), " - 32)), uint2(0)) : uint4(1 << ",
12013 						          to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
12014 					}
12015 				});
12016 				break;
12017 			case BuiltInSubgroupGeMask:
12018 				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
12019 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
12020 				if (!msl_options.supports_msl_version(2, 1))
12021 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
12022 				if (msl_options.fixed_subgroup_size != 0)
12023 					add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
12024 				entry_func.fixup_hooks_in.push_back([=]() {
12025 					// Case where index < 32, size < 32:
12026 					// mask0 = bfi(0, 0xFFFFFFFF, index, size - index);
12027 					// mask1 = bfi(0, 0xFFFFFFFF, 0, 0); // Gives 0
12028 					// Case where index < 32 but size >= 32:
12029 					// mask0 = bfi(0, 0xFFFFFFFF, index, 32 - index);
12030 					// mask1 = bfi(0, 0xFFFFFFFF, 0, size - 32);
12031 					// Case where index >= 32:
12032 					// mask0 = bfi(0, 0xFFFFFFFF, 32, 0); // Gives 0
12033 					// mask1 = bfi(0, 0xFFFFFFFF, index - 32, size - index);
12034 					// This is expressed without branches to avoid divergent
12035 					// control flow--hence the complicated min/max expressions.
12036 					// This is further complicated by the fact that if you attempt
12037 					// to bfi/bfe out-of-bounds on Metal, undefined behavior is the
12038 					// result.
12039 					if (msl_options.fixed_subgroup_size > 32)
12040 					{
12041 						// Don't use the subgroup size variable with fixed subgroup sizes,
12042 						// since the variables could be defined in the wrong order.
12043 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12044 						          " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
12045 						          to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(32 - (int)",
12046 						          to_expression(builtin_subgroup_invocation_id_id),
12047 						          ", 0)), insert_bits(0u, 0xFFFFFFFF,"
12048 						          " (uint)max((int)",
12049 						          to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), ",
12050 						          msl_options.fixed_subgroup_size, " - max(",
12051 						          to_expression(builtin_subgroup_invocation_id_id),
12052 						          ", 32u)), uint2(0));");
12053 					}
12054 					else if (msl_options.fixed_subgroup_size != 0)
12055 					{
12056 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12057 						          " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
12058 						          to_expression(builtin_subgroup_invocation_id_id), ", ",
12059 						          msl_options.fixed_subgroup_size, " - ",
12060 						          to_expression(builtin_subgroup_invocation_id_id),
12061 						          "), uint3(0));");
12062 					}
12063 					else if (msl_options.is_ios())
12064 					{
12065 						// On iOS, the SIMD-group size will currently never exceed 32.
12066 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12067 						          " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
12068 						          to_expression(builtin_subgroup_invocation_id_id), ", ",
12069 						          to_expression(builtin_subgroup_size_id), " - ",
12070 						          to_expression(builtin_subgroup_invocation_id_id), "), uint3(0));");
12071 					}
12072 					else
12073 					{
12074 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12075 						          " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
12076 						          to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(min((int)",
12077 						          to_expression(builtin_subgroup_size_id), ", 32) - (int)",
12078 						          to_expression(builtin_subgroup_invocation_id_id),
12079 						          ", 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
12080 						          to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), (uint)max((int)",
12081 						          to_expression(builtin_subgroup_size_id), " - (int)max(",
12082 						          to_expression(builtin_subgroup_invocation_id_id), ", 32u), 0)), uint2(0));");
12083 					}
12084 				});
12085 				break;
12086 			case BuiltInSubgroupGtMask:
12087 				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
12088 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
12089 				if (!msl_options.supports_msl_version(2, 1))
12090 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
12091 				add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
12092 				entry_func.fixup_hooks_in.push_back([=]() {
12093 					// The same logic applies here, except now the index is one
12094 					// more than the subgroup invocation ID.
12095 					if (msl_options.fixed_subgroup_size > 32)
12096 					{
12097 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12098 						          " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
12099 						          to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(32 - (int)",
12100 						          to_expression(builtin_subgroup_invocation_id_id),
12101 						          " - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
12102 						          to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), ",
12103 						          msl_options.fixed_subgroup_size, " - max(",
12104 						          to_expression(builtin_subgroup_invocation_id_id),
12105 						          " + 1, 32u)), uint2(0));");
12106 					}
12107 					else if (msl_options.fixed_subgroup_size != 0)
12108 					{
12109 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12110 						          " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
12111 						          to_expression(builtin_subgroup_invocation_id_id), " + 1, ",
12112 						          msl_options.fixed_subgroup_size, " - ",
12113 						          to_expression(builtin_subgroup_invocation_id_id),
12114 						          " - 1), uint3(0));");
12115 					}
12116 					else if (msl_options.is_ios())
12117 					{
12118 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12119 						          " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
12120 						          to_expression(builtin_subgroup_invocation_id_id), " + 1, ",
12121 						          to_expression(builtin_subgroup_size_id), " - ",
12122 						          to_expression(builtin_subgroup_invocation_id_id), " - 1), uint3(0));");
12123 					}
12124 					else
12125 					{
12126 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12127 						          " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
12128 						          to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(min((int)",
12129 						          to_expression(builtin_subgroup_size_id), ", 32) - (int)",
12130 						          to_expression(builtin_subgroup_invocation_id_id),
12131 						          " - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
12132 						          to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), (uint)max((int)",
12133 						          to_expression(builtin_subgroup_size_id), " - (int)max(",
12134 						          to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), 0)), uint2(0));");
12135 					}
12136 				});
12137 				break;
12138 			case BuiltInSubgroupLeMask:
12139 				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
12140 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
12141 				if (!msl_options.supports_msl_version(2, 1))
12142 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
12143 				add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
12144 				entry_func.fixup_hooks_in.push_back([=]() {
12145 					if (msl_options.is_ios())
12146 					{
12147 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12148 						          " = uint4(extract_bits(0xFFFFFFFF, 0, ",
12149 						          to_expression(builtin_subgroup_invocation_id_id), " + 1), uint3(0));");
12150 					}
12151 					else
12152 					{
12153 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12154 						          " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
12155 						          to_expression(builtin_subgroup_invocation_id_id),
12156 						          " + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
12157 						          to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0)), uint2(0));");
12158 					}
12159 				});
12160 				break;
12161 			case BuiltInSubgroupLtMask:
12162 				if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
12163 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
12164 				if (!msl_options.supports_msl_version(2, 1))
12165 					SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
12166 				add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
12167 				entry_func.fixup_hooks_in.push_back([=]() {
12168 					if (msl_options.is_ios())
12169 					{
12170 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12171 						          " = uint4(extract_bits(0xFFFFFFFF, 0, ",
12172 						          to_expression(builtin_subgroup_invocation_id_id), "), uint3(0));");
12173 					}
12174 					else
12175 					{
12176 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12177 						          " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
12178 						          to_expression(builtin_subgroup_invocation_id_id),
12179 						          ", 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
12180 						          to_expression(builtin_subgroup_invocation_id_id), " - 32, 0)), uint2(0));");
12181 					}
12182 				});
12183 				break;
12184 			case BuiltInViewIndex:
12185 				if (!msl_options.multiview)
12186 				{
12187 					// According to the Vulkan spec, when not running under a multiview
12188 					// render pass, ViewIndex is 0.
12189 					entry_func.fixup_hooks_in.push_back([=]() {
12190 						statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;");
12191 					});
12192 				}
12193 				else if (msl_options.view_index_from_device_index)
12194 				{
12195 					// In this case, we take the view index from that of the device we're running on.
12196 					entry_func.fixup_hooks_in.push_back([=]() {
12197 						statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12198 						          msl_options.device_index, ";");
12199 					});
12200 					// We actually don't want to set the render_target_array_index here.
12201 					// Since every physical device is rendering a different view,
12202 					// there's no need for layered rendering here.
12203 				}
12204 				else if (!msl_options.multiview_layered_rendering)
12205 				{
12206 					// In this case, the views are rendered one at a time. The view index, then,
12207 					// is just the first part of the "view mask".
12208 					entry_func.fixup_hooks_in.push_back([=]() {
12209 						statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12210 						          to_expression(view_mask_buffer_id), "[0];");
12211 					});
12212 				}
12213 				else if (get_execution_model() == ExecutionModelFragment)
12214 				{
12215 					// Because we adjusted the view index in the vertex shader, we have to
12216 					// adjust it back here.
12217 					entry_func.fixup_hooks_in.push_back([=]() {
12218 						statement(to_expression(var_id), " += ", to_expression(view_mask_buffer_id), "[0];");
12219 					});
12220 				}
12221 				else if (get_execution_model() == ExecutionModelVertex)
12222 				{
12223 					// Metal provides no special support for multiview, so we smuggle
12224 					// the view index in the instance index.
12225 					entry_func.fixup_hooks_in.push_back([=]() {
12226 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12227 						          to_expression(view_mask_buffer_id), "[0] + (", to_expression(builtin_instance_idx_id),
12228 						          " - ", to_expression(builtin_base_instance_id), ") % ",
12229 						          to_expression(view_mask_buffer_id), "[1];");
12230 						statement(to_expression(builtin_instance_idx_id), " = (",
12231 						          to_expression(builtin_instance_idx_id), " - ",
12232 						          to_expression(builtin_base_instance_id), ") / ", to_expression(view_mask_buffer_id),
12233 						          "[1] + ", to_expression(builtin_base_instance_id), ";");
12234 					});
12235 					// In addition to setting the variable itself, we also need to
12236 					// set the render_target_array_index with it on output. We have to
12237 					// offset this by the base view index, because Metal isn't in on
12238 					// our little game here.
12239 					entry_func.fixup_hooks_out.push_back([=]() {
12240 						statement(to_expression(builtin_layer_id), " = ", to_expression(var_id), " - ",
12241 						          to_expression(view_mask_buffer_id), "[0];");
12242 					});
12243 				}
12244 				break;
12245 			case BuiltInDeviceIndex:
12246 				// Metal pipelines belong to the devices which create them, so we'll
12247 				// need to create a MTLPipelineState for every MTLDevice in a grouped
12248 				// VkDevice. We can assume, then, that the device index is constant.
12249 				entry_func.fixup_hooks_in.push_back([=]() {
12250 					statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12251 					          msl_options.device_index, ";");
12252 				});
12253 				break;
12254 			case BuiltInWorkgroupId:
12255 				if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInWorkgroupId))
12256 					break;
12257 
12258 				// The vkCmdDispatchBase() command lets the client set the base value
12259 				// of WorkgroupId. Metal has no direct equivalent; we must make this
12260 				// adjustment ourselves.
12261 				entry_func.fixup_hooks_in.push_back([=]() {
12262 					statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id), ";");
12263 				});
12264 				break;
12265 			case BuiltInGlobalInvocationId:
12266 				if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInGlobalInvocationId))
12267 					break;
12268 
12269 				// GlobalInvocationId is defined as LocalInvocationId + WorkgroupId * WorkgroupSize.
12270 				// This needs to be adjusted too.
12271 				entry_func.fixup_hooks_in.push_back([=]() {
12272 					auto &execution = this->get_entry_point();
12273 					uint32_t workgroup_size_id = execution.workgroup_size.constant;
12274 					if (workgroup_size_id)
12275 						statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
12276 						          " * ", to_expression(workgroup_size_id), ";");
12277 					else
12278 						statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
12279 						          " * uint3(", execution.workgroup_size.x, ", ", execution.workgroup_size.y, ", ",
12280 						          execution.workgroup_size.z, ");");
12281 				});
12282 				break;
12283 			case BuiltInVertexId:
12284 			case BuiltInVertexIndex:
12285 				// This is direct-mapped normally.
12286 				if (!msl_options.vertex_for_tessellation)
12287 					break;
12288 
12289 				entry_func.fixup_hooks_in.push_back([=]() {
12290 					builtin_declaration = true;
12291 					switch (msl_options.vertex_index_type)
12292 					{
12293 					case Options::IndexType::None:
12294 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12295 						          to_expression(builtin_invocation_id_id), ".x + ",
12296 						          to_expression(builtin_dispatch_base_id), ".x;");
12297 						break;
12298 					case Options::IndexType::UInt16:
12299 					case Options::IndexType::UInt32:
12300 						statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", index_buffer_var_name,
12301 						          "[", to_expression(builtin_invocation_id_id), ".x] + ",
12302 						          to_expression(builtin_dispatch_base_id), ".x;");
12303 						break;
12304 					}
12305 					builtin_declaration = false;
12306 				});
12307 				break;
12308 			case BuiltInBaseVertex:
12309 				// This is direct-mapped normally.
12310 				if (!msl_options.vertex_for_tessellation)
12311 					break;
12312 
12313 				entry_func.fixup_hooks_in.push_back([=]() {
12314 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12315 					          to_expression(builtin_dispatch_base_id), ".x;");
12316 				});
12317 				break;
12318 			case BuiltInInstanceId:
12319 			case BuiltInInstanceIndex:
12320 				// This is direct-mapped normally.
12321 				if (!msl_options.vertex_for_tessellation)
12322 					break;
12323 
12324 				entry_func.fixup_hooks_in.push_back([=]() {
12325 					builtin_declaration = true;
12326 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12327 					          to_expression(builtin_invocation_id_id), ".y + ", to_expression(builtin_dispatch_base_id),
12328 					          ".y;");
12329 					builtin_declaration = false;
12330 				});
12331 				break;
12332 			case BuiltInBaseInstance:
12333 				// This is direct-mapped normally.
12334 				if (!msl_options.vertex_for_tessellation)
12335 					break;
12336 
12337 				entry_func.fixup_hooks_in.push_back([=]() {
12338 					statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12339 					          to_expression(builtin_dispatch_base_id), ".y;");
12340 				});
12341 				break;
12342 			default:
12343 				break;
12344 			}
12345 		}
12346 		else if (var.storage == StorageClassOutput && get_execution_model() == ExecutionModelFragment &&
12347 				 is_builtin_variable(var) && active_output_builtins.get(bi_type) &&
12348 				 bi_type == BuiltInSampleMask && has_additional_fixed_sample_mask())
12349 		{
12350 			// If the additional fixed sample mask was set, we need to adjust the sample_mask
12351 			// output to reflect that. If the shader outputs the sample_mask itself too, we need
12352 			// to AND the two masks to get the final one.
12353 			string op_str = does_shader_write_sample_mask ? " &= " : " = ";
12354 			entry_func.fixup_hooks_out.push_back([=]() {
12355 				statement(to_expression(builtin_sample_mask_id), op_str, additional_fixed_sample_mask_str(), ";");
12356 			});
12357 		}
12358 	});
12359 }
12360 
12361 // Returns the Metal index of the resource of the specified type as used by the specified variable.
get_metal_resource_index(SPIRVariable & var,SPIRType::BaseType basetype,uint32_t plane)12362 uint32_t CompilerMSL::get_metal_resource_index(SPIRVariable &var, SPIRType::BaseType basetype, uint32_t plane)
12363 {
12364 	auto &execution = get_entry_point();
12365 	auto &var_dec = ir.meta[var.self].decoration;
12366 	auto &var_type = get<SPIRType>(var.basetype);
12367 	uint32_t var_desc_set = (var.storage == StorageClassPushConstant) ? kPushConstDescSet : var_dec.set;
12368 	uint32_t var_binding = (var.storage == StorageClassPushConstant) ? kPushConstBinding : var_dec.binding;
12369 
12370 	// If a matching binding has been specified, find and use it.
12371 	auto itr = resource_bindings.find({ execution.model, var_desc_set, var_binding });
12372 
12373 	// Atomic helper buffers for image atomics need to use secondary bindings as well.
12374 	bool use_secondary_binding = (var_type.basetype == SPIRType::SampledImage && basetype == SPIRType::Sampler) ||
12375 	                             basetype == SPIRType::AtomicCounter;
12376 
12377 	auto resource_decoration =
12378 	    use_secondary_binding ? SPIRVCrossDecorationResourceIndexSecondary : SPIRVCrossDecorationResourceIndexPrimary;
12379 
12380 	if (plane == 1)
12381 		resource_decoration = SPIRVCrossDecorationResourceIndexTertiary;
12382 	if (plane == 2)
12383 		resource_decoration = SPIRVCrossDecorationResourceIndexQuaternary;
12384 
12385 	if (itr != end(resource_bindings))
12386 	{
12387 		auto &remap = itr->second;
12388 		remap.second = true;
12389 		switch (basetype)
12390 		{
12391 		case SPIRType::Image:
12392 			set_extended_decoration(var.self, resource_decoration, remap.first.msl_texture + plane);
12393 			return remap.first.msl_texture + plane;
12394 		case SPIRType::Sampler:
12395 			set_extended_decoration(var.self, resource_decoration, remap.first.msl_sampler);
12396 			return remap.first.msl_sampler;
12397 		default:
12398 			set_extended_decoration(var.self, resource_decoration, remap.first.msl_buffer);
12399 			return remap.first.msl_buffer;
12400 		}
12401 	}
12402 
12403 	// If we have already allocated an index, keep using it.
12404 	if (has_extended_decoration(var.self, resource_decoration))
12405 		return get_extended_decoration(var.self, resource_decoration);
12406 
12407 	auto &type = get<SPIRType>(var.basetype);
12408 
12409 	if (type_is_msl_framebuffer_fetch(type))
12410 	{
12411 		// Frame-buffer fetch gets its fallback resource index from the input attachment index,
12412 		// which is then treated as color index.
12413 		return get_decoration(var.self, DecorationInputAttachmentIndex);
12414 	}
12415 	else if (msl_options.enable_decoration_binding)
12416 	{
12417 		// Allow user to enable decoration binding.
12418 		// If there is no explicit mapping of bindings to MSL, use the declared binding as a fallback.
12419 		if (has_decoration(var.self, DecorationBinding))
12420 		{
12421 			var_binding = get_decoration(var.self, DecorationBinding);
12422 			// Avoid emitting sentinel bindings.
12423 			if (var_binding < 0x80000000u)
12424 				return var_binding;
12425 		}
12426 	}
12427 
12428 	// If we did not explicitly remap, allocate bindings on demand.
12429 	// We cannot reliably use Binding decorations since SPIR-V and MSL's binding models are very different.
12430 
12431 	bool allocate_argument_buffer_ids = false;
12432 
12433 	if (var.storage != StorageClassPushConstant)
12434 		allocate_argument_buffer_ids = descriptor_set_is_argument_buffer(var_desc_set);
12435 
12436 	uint32_t binding_stride = 1;
12437 	for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
12438 		binding_stride *= to_array_size_literal(type, i);
12439 
12440 	assert(binding_stride != 0);
12441 
12442 	// If a binding has not been specified, revert to incrementing resource indices.
12443 	uint32_t resource_index;
12444 
12445 	if (allocate_argument_buffer_ids)
12446 	{
12447 		// Allocate from a flat ID binding space.
12448 		resource_index = next_metal_resource_ids[var_desc_set];
12449 		next_metal_resource_ids[var_desc_set] += binding_stride;
12450 	}
12451 	else
12452 	{
12453 		// Allocate from plain bindings which are allocated per resource type.
12454 		switch (basetype)
12455 		{
12456 		case SPIRType::Image:
12457 			resource_index = next_metal_resource_index_texture;
12458 			next_metal_resource_index_texture += binding_stride;
12459 			break;
12460 		case SPIRType::Sampler:
12461 			resource_index = next_metal_resource_index_sampler;
12462 			next_metal_resource_index_sampler += binding_stride;
12463 			break;
12464 		default:
12465 			resource_index = next_metal_resource_index_buffer;
12466 			next_metal_resource_index_buffer += binding_stride;
12467 			break;
12468 		}
12469 	}
12470 
12471 	set_extended_decoration(var.self, resource_decoration, resource_index);
12472 	return resource_index;
12473 }
12474 
type_is_msl_framebuffer_fetch(const SPIRType & type) const12475 bool CompilerMSL::type_is_msl_framebuffer_fetch(const SPIRType &type) const
12476 {
12477 	return type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
12478 	       msl_options.use_framebuffer_fetch_subpasses;
12479 }
12480 
type_is_pointer(const SPIRType & type) const12481 bool CompilerMSL::type_is_pointer(const SPIRType &type) const
12482 {
12483 	if (!type.pointer)
12484 		return false;
12485 	auto &parent_type = get<SPIRType>(type.parent_type);
12486 	// Safeguards when we forget to set pointer_depth (there is an assert for it in type_to_glsl),
12487 	// but the extra check shouldn't hurt.
12488 	return (type.pointer_depth > parent_type.pointer_depth) || !parent_type.pointer;
12489 }
12490 
type_is_pointer_to_pointer(const SPIRType & type) const12491 bool CompilerMSL::type_is_pointer_to_pointer(const SPIRType &type) const
12492 {
12493 	if (!type.pointer)
12494 		return false;
12495 	auto &parent_type = get<SPIRType>(type.parent_type);
12496 	return type.pointer_depth > parent_type.pointer_depth && type_is_pointer(parent_type);
12497 }
12498 
argument_decl(const SPIRFunction::Parameter & arg)12499 string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
12500 {
12501 	auto &var = get<SPIRVariable>(arg.id);
12502 	auto &type = get_variable_data_type(var);
12503 	auto &var_type = get<SPIRType>(arg.type);
12504 	StorageClass type_storage = var_type.storage;
12505 	bool is_pointer = var_type.pointer;
12506 
12507 	// If we need to modify the name of the variable, make sure we use the original variable.
12508 	// Our alias is just a shadow variable.
12509 	uint32_t name_id = var.self;
12510 	if (arg.alias_global_variable && var.basevariable)
12511 		name_id = var.basevariable;
12512 
12513 	bool constref = !arg.alias_global_variable && is_pointer && arg.write_count == 0;
12514 	// Framebuffer fetch is plain value, const looks out of place, but it is not wrong.
12515 	if (type_is_msl_framebuffer_fetch(type))
12516 		constref = false;
12517 
12518 	bool type_is_image = type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage ||
12519 	                     type.basetype == SPIRType::Sampler;
12520 
12521 	// Arrays of images/samplers in MSL are always const.
12522 	if (!type.array.empty() && type_is_image)
12523 		constref = true;
12524 
12525 	const char *cv_qualifier = constref ? "const " : "";
12526 	string decl;
12527 
12528 	// If this is a combined image-sampler for a 2D image with floating-point type,
12529 	// we emitted the 'spvDynamicImageSampler' type, and this is *not* an alias parameter
12530 	// for a global, then we need to emit a "dynamic" combined image-sampler.
12531 	// Unfortunately, this is necessary to properly support passing around
12532 	// combined image-samplers with Y'CbCr conversions on them.
12533 	bool is_dynamic_img_sampler = !arg.alias_global_variable && type.basetype == SPIRType::SampledImage &&
12534 	                              type.image.dim == Dim2D && type_is_floating_point(get<SPIRType>(type.image.type)) &&
12535 	                              spv_function_implementations.count(SPVFuncImplDynamicImageSampler);
12536 
12537 	// Allow Metal to use the array<T> template to make arrays a value type
12538 	string address_space = get_argument_address_space(var);
12539 	bool builtin = has_decoration(var.self, DecorationBuiltIn);
12540 	auto builtin_type = BuiltIn(get_decoration(arg.id, DecorationBuiltIn));
12541 
12542 	if (address_space == "threadgroup")
12543 		is_using_builtin_array = true;
12544 
12545 	if (var.basevariable && (var.basevariable == stage_in_ptr_var_id || var.basevariable == stage_out_ptr_var_id))
12546 		decl = join(cv_qualifier, type_to_glsl(type, arg.id));
12547 	else if (builtin)
12548 	{
12549 		// Only use templated array for Clip/Cull distance when feasible.
12550 		// In other scenarios, we need need to override array length for tess levels (if used as outputs),
12551 		// or we need to emit the expected type for builtins (uint vs int).
12552 		auto storage = get<SPIRType>(var.basetype).storage;
12553 
12554 		if (storage == StorageClassInput &&
12555 		    (builtin_type == BuiltInTessLevelInner || builtin_type == BuiltInTessLevelOuter))
12556 		{
12557 			is_using_builtin_array = false;
12558 		}
12559 		else if (builtin_type != BuiltInClipDistance && builtin_type != BuiltInCullDistance)
12560 		{
12561 			is_using_builtin_array = true;
12562 		}
12563 
12564 		if (storage == StorageClassOutput && variable_storage_requires_stage_io(storage) &&
12565 		    !is_stage_output_builtin_masked(builtin_type))
12566 			is_using_builtin_array = true;
12567 
12568 		if (is_using_builtin_array)
12569 			decl = join(cv_qualifier, builtin_type_decl(builtin_type, arg.id));
12570 		else
12571 			decl = join(cv_qualifier, type_to_glsl(type, arg.id));
12572 	}
12573 	else if ((type_storage == StorageClassUniform || type_storage == StorageClassStorageBuffer) && is_array(type))
12574 	{
12575 		is_using_builtin_array = true;
12576 		decl += join(cv_qualifier, type_to_glsl(type, arg.id), "*");
12577 	}
12578 	else if (is_dynamic_img_sampler)
12579 	{
12580 		decl = join(cv_qualifier, "spvDynamicImageSampler<", type_to_glsl(get<SPIRType>(type.image.type)), ">");
12581 		// Mark the variable so that we can handle passing it to another function.
12582 		set_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
12583 	}
12584 	else
12585 	{
12586 		// The type is a pointer type we need to emit cv_qualifier late.
12587 		if (type_is_pointer(type))
12588 		{
12589 			decl = type_to_glsl(type, arg.id);
12590 			if (*cv_qualifier != '\0')
12591 				decl += join(" ", cv_qualifier);
12592 		}
12593 		else
12594 			decl = join(cv_qualifier, type_to_glsl(type, arg.id));
12595 	}
12596 
12597 	bool opaque_handle = type_storage == StorageClassUniformConstant;
12598 
12599 	if (!builtin && !opaque_handle && !is_pointer &&
12600 	    (type_storage == StorageClassFunction || type_storage == StorageClassGeneric))
12601 	{
12602 		// If the argument is a pure value and not an opaque type, we will pass by value.
12603 		if (msl_options.force_native_arrays && is_array(type))
12604 		{
12605 			// We are receiving an array by value. This is problematic.
12606 			// We cannot be sure of the target address space since we are supposed to receive a copy,
12607 			// but this is not possible with MSL without some extra work.
12608 			// We will have to assume we're getting a reference in thread address space.
12609 			// If we happen to get a reference in constant address space, the caller must emit a copy and pass that.
12610 			// Thread const therefore becomes the only logical choice, since we cannot "create" a constant array from
12611 			// non-constant arrays, but we can create thread const from constant.
12612 			decl = string("thread const ") + decl;
12613 			decl += " (&";
12614 			const char *restrict_kw = to_restrict(name_id);
12615 			if (*restrict_kw)
12616 			{
12617 				decl += " ";
12618 				decl += restrict_kw;
12619 			}
12620 			decl += to_expression(name_id);
12621 			decl += ")";
12622 			decl += type_to_array_glsl(type);
12623 		}
12624 		else
12625 		{
12626 			if (!address_space.empty())
12627 				decl = join(address_space, " ", decl);
12628 			decl += " ";
12629 			decl += to_expression(name_id);
12630 		}
12631 	}
12632 	else if (is_array(type) && !type_is_image)
12633 	{
12634 		// Arrays of images and samplers are special cased.
12635 		if (!address_space.empty())
12636 			decl = join(address_space, " ", decl);
12637 
12638 		if (msl_options.argument_buffers)
12639 		{
12640 			uint32_t desc_set = get_decoration(name_id, DecorationDescriptorSet);
12641 			if ((type_storage == StorageClassUniform || type_storage == StorageClassStorageBuffer) &&
12642 			    descriptor_set_is_argument_buffer(desc_set))
12643 			{
12644 				// An awkward case where we need to emit *more* address space declarations (yay!).
12645 				// An example is where we pass down an array of buffer pointers to leaf functions.
12646 				// It's a constant array containing pointers to constants.
12647 				// The pointer array is always constant however. E.g.
12648 				// device SSBO * constant (&array)[N].
12649 				// const device SSBO * constant (&array)[N].
12650 				// constant SSBO * constant (&array)[N].
12651 				// However, this only matters for argument buffers, since for MSL 1.0 style codegen,
12652 				// we emit the buffer array on stack instead, and that seems to work just fine apparently.
12653 
12654 				// If the argument was marked as being in device address space, any pointer to member would
12655 				// be const device, not constant.
12656 				if (argument_buffer_device_storage_mask & (1u << desc_set))
12657 					decl += " const device";
12658 				else
12659 					decl += " constant";
12660 			}
12661 		}
12662 
12663 		// Special case, need to override the array size here if we're using tess level as an argument.
12664 		if (get_execution_model() == ExecutionModelTessellationControl && builtin &&
12665 		    (builtin_type == BuiltInTessLevelInner || builtin_type == BuiltInTessLevelOuter))
12666 		{
12667 			uint32_t array_size = get_physical_tess_level_array_size(builtin_type);
12668 			if (array_size == 1)
12669 			{
12670 				decl += " &";
12671 				decl += to_expression(name_id);
12672 			}
12673 			else
12674 			{
12675 				decl += " (&";
12676 				decl += to_expression(name_id);
12677 				decl += ")";
12678 				decl += join("[", array_size, "]");
12679 			}
12680 		}
12681 		else
12682 		{
12683 			auto array_size_decl = type_to_array_glsl(type);
12684 			if (array_size_decl.empty())
12685 				decl += "& ";
12686 			else
12687 				decl += " (&";
12688 
12689 			const char *restrict_kw = to_restrict(name_id);
12690 			if (*restrict_kw)
12691 			{
12692 				decl += " ";
12693 				decl += restrict_kw;
12694 			}
12695 			decl += to_expression(name_id);
12696 
12697 			if (!array_size_decl.empty())
12698 			{
12699 				decl += ")";
12700 				decl += array_size_decl;
12701 			}
12702 		}
12703 	}
12704 	else if (!opaque_handle && (!pull_model_inputs.count(var.basevariable) || type.basetype == SPIRType::Struct))
12705 	{
12706 		// If this is going to be a reference to a variable pointer, the address space
12707 		// for the reference has to go before the '&', but after the '*'.
12708 		if (!address_space.empty())
12709 		{
12710 			if (type_is_pointer(type))
12711 			{
12712 				if (*cv_qualifier == '\0')
12713 					decl += ' ';
12714 				decl += join(address_space, " ");
12715 			}
12716 			else
12717 				decl = join(address_space, " ", decl);
12718 		}
12719 		decl += "&";
12720 		decl += " ";
12721 		decl += to_restrict(name_id);
12722 		decl += to_expression(name_id);
12723 	}
12724 	else
12725 	{
12726 		if (!address_space.empty())
12727 			decl = join(address_space, " ", decl);
12728 		decl += " ";
12729 		decl += to_expression(name_id);
12730 	}
12731 
12732 	// Emulate texture2D atomic operations
12733 	auto *backing_var = maybe_get_backing_variable(name_id);
12734 	if (backing_var && atomic_image_vars.count(backing_var->self))
12735 	{
12736 		decl += ", device atomic_" + type_to_glsl(get<SPIRType>(var_type.image.type), 0);
12737 		decl += "* " + to_expression(name_id) + "_atomic";
12738 	}
12739 
12740 	is_using_builtin_array = false;
12741 
12742 	return decl;
12743 }
12744 
12745 // If we're currently in the entry point function, and the object
12746 // has a qualified name, use it, otherwise use the standard name.
to_name(uint32_t id,bool allow_alias) const12747 string CompilerMSL::to_name(uint32_t id, bool allow_alias) const
12748 {
12749 	if (current_function && (current_function->self == ir.default_entry_point))
12750 	{
12751 		auto *m = ir.find_meta(id);
12752 		if (m && !m->decoration.qualified_alias.empty())
12753 			return m->decoration.qualified_alias;
12754 	}
12755 	return Compiler::to_name(id, allow_alias);
12756 }
12757 
12758 // Returns a name that combines the name of the struct with the name of the member, except for Builtins
to_qualified_member_name(const SPIRType & type,uint32_t index)12759 string CompilerMSL::to_qualified_member_name(const SPIRType &type, uint32_t index)
12760 {
12761 	// Don't qualify Builtin names because they are unique and are treated as such when building expressions
12762 	BuiltIn builtin = BuiltInMax;
12763 	if (is_member_builtin(type, index, &builtin))
12764 		return builtin_to_glsl(builtin, type.storage);
12765 
12766 	// Strip any underscore prefix from member name
12767 	string mbr_name = to_member_name(type, index);
12768 	size_t startPos = mbr_name.find_first_not_of("_");
12769 	mbr_name = (startPos != string::npos) ? mbr_name.substr(startPos) : "";
12770 	return join(to_name(type.self), "_", mbr_name);
12771 }
12772 
12773 // Ensures that the specified name is permanently usable by prepending a prefix
12774 // if the first chars are _ and a digit, which indicate a transient name.
ensure_valid_name(string name,string pfx)12775 string CompilerMSL::ensure_valid_name(string name, string pfx)
12776 {
12777 	return (name.size() >= 2 && name[0] == '_' && isdigit(name[1])) ? (pfx + name) : name;
12778 }
12779 
get_reserved_keyword_set()12780 const std::unordered_set<std::string> &CompilerMSL::get_reserved_keyword_set()
12781 {
12782 	static const unordered_set<string> keywords = {
12783 		"kernel",
12784 		"vertex",
12785 		"fragment",
12786 		"compute",
12787 		"bias",
12788 		"level",
12789 		"gradient2d",
12790 		"gradientcube",
12791 		"gradient3d",
12792 		"min_lod_clamp",
12793 		"assert",
12794 		"VARIABLE_TRACEPOINT",
12795 		"STATIC_DATA_TRACEPOINT",
12796 		"STATIC_DATA_TRACEPOINT_V",
12797 		"METAL_ALIGN",
12798 		"METAL_ASM",
12799 		"METAL_CONST",
12800 		"METAL_DEPRECATED",
12801 		"METAL_ENABLE_IF",
12802 		"METAL_FUNC",
12803 		"METAL_INTERNAL",
12804 		"METAL_NON_NULL_RETURN",
12805 		"METAL_NORETURN",
12806 		"METAL_NOTHROW",
12807 		"METAL_PURE",
12808 		"METAL_UNAVAILABLE",
12809 		"METAL_IMPLICIT",
12810 		"METAL_EXPLICIT",
12811 		"METAL_CONST_ARG",
12812 		"METAL_ARG_UNIFORM",
12813 		"METAL_ZERO_ARG",
12814 		"METAL_VALID_LOD_ARG",
12815 		"METAL_VALID_LEVEL_ARG",
12816 		"METAL_VALID_STORE_ORDER",
12817 		"METAL_VALID_LOAD_ORDER",
12818 		"METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
12819 		"METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
12820 		"METAL_VALID_RENDER_TARGET",
12821 		"is_function_constant_defined",
12822 		"CHAR_BIT",
12823 		"SCHAR_MAX",
12824 		"SCHAR_MIN",
12825 		"UCHAR_MAX",
12826 		"CHAR_MAX",
12827 		"CHAR_MIN",
12828 		"USHRT_MAX",
12829 		"SHRT_MAX",
12830 		"SHRT_MIN",
12831 		"UINT_MAX",
12832 		"INT_MAX",
12833 		"INT_MIN",
12834 		"FLT_DIG",
12835 		"FLT_MANT_DIG",
12836 		"FLT_MAX_10_EXP",
12837 		"FLT_MAX_EXP",
12838 		"FLT_MIN_10_EXP",
12839 		"FLT_MIN_EXP",
12840 		"FLT_RADIX",
12841 		"FLT_MAX",
12842 		"FLT_MIN",
12843 		"FLT_EPSILON",
12844 		"FP_ILOGB0",
12845 		"FP_ILOGBNAN",
12846 		"MAXFLOAT",
12847 		"HUGE_VALF",
12848 		"INFINITY",
12849 		"NAN",
12850 		"M_E_F",
12851 		"M_LOG2E_F",
12852 		"M_LOG10E_F",
12853 		"M_LN2_F",
12854 		"M_LN10_F",
12855 		"M_PI_F",
12856 		"M_PI_2_F",
12857 		"M_PI_4_F",
12858 		"M_1_PI_F",
12859 		"M_2_PI_F",
12860 		"M_2_SQRTPI_F",
12861 		"M_SQRT2_F",
12862 		"M_SQRT1_2_F",
12863 		"HALF_DIG",
12864 		"HALF_MANT_DIG",
12865 		"HALF_MAX_10_EXP",
12866 		"HALF_MAX_EXP",
12867 		"HALF_MIN_10_EXP",
12868 		"HALF_MIN_EXP",
12869 		"HALF_RADIX",
12870 		"HALF_MAX",
12871 		"HALF_MIN",
12872 		"HALF_EPSILON",
12873 		"MAXHALF",
12874 		"HUGE_VALH",
12875 		"M_E_H",
12876 		"M_LOG2E_H",
12877 		"M_LOG10E_H",
12878 		"M_LN2_H",
12879 		"M_LN10_H",
12880 		"M_PI_H",
12881 		"M_PI_2_H",
12882 		"M_PI_4_H",
12883 		"M_1_PI_H",
12884 		"M_2_PI_H",
12885 		"M_2_SQRTPI_H",
12886 		"M_SQRT2_H",
12887 		"M_SQRT1_2_H",
12888 		"DBL_DIG",
12889 		"DBL_MANT_DIG",
12890 		"DBL_MAX_10_EXP",
12891 		"DBL_MAX_EXP",
12892 		"DBL_MIN_10_EXP",
12893 		"DBL_MIN_EXP",
12894 		"DBL_RADIX",
12895 		"DBL_MAX",
12896 		"DBL_MIN",
12897 		"DBL_EPSILON",
12898 		"HUGE_VAL",
12899 		"M_E",
12900 		"M_LOG2E",
12901 		"M_LOG10E",
12902 		"M_LN2",
12903 		"M_LN10",
12904 		"M_PI",
12905 		"M_PI_2",
12906 		"M_PI_4",
12907 		"M_1_PI",
12908 		"M_2_PI",
12909 		"M_2_SQRTPI",
12910 		"M_SQRT2",
12911 		"M_SQRT1_2",
12912 		"quad_broadcast",
12913 	};
12914 
12915 	return keywords;
12916 }
12917 
get_illegal_func_names()12918 const std::unordered_set<std::string> &CompilerMSL::get_illegal_func_names()
12919 {
12920 	static const unordered_set<string> illegal_func_names = {
12921 		"main",
12922 		"saturate",
12923 		"assert",
12924 		"fmin3",
12925 		"fmax3",
12926 		"VARIABLE_TRACEPOINT",
12927 		"STATIC_DATA_TRACEPOINT",
12928 		"STATIC_DATA_TRACEPOINT_V",
12929 		"METAL_ALIGN",
12930 		"METAL_ASM",
12931 		"METAL_CONST",
12932 		"METAL_DEPRECATED",
12933 		"METAL_ENABLE_IF",
12934 		"METAL_FUNC",
12935 		"METAL_INTERNAL",
12936 		"METAL_NON_NULL_RETURN",
12937 		"METAL_NORETURN",
12938 		"METAL_NOTHROW",
12939 		"METAL_PURE",
12940 		"METAL_UNAVAILABLE",
12941 		"METAL_IMPLICIT",
12942 		"METAL_EXPLICIT",
12943 		"METAL_CONST_ARG",
12944 		"METAL_ARG_UNIFORM",
12945 		"METAL_ZERO_ARG",
12946 		"METAL_VALID_LOD_ARG",
12947 		"METAL_VALID_LEVEL_ARG",
12948 		"METAL_VALID_STORE_ORDER",
12949 		"METAL_VALID_LOAD_ORDER",
12950 		"METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
12951 		"METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
12952 		"METAL_VALID_RENDER_TARGET",
12953 		"is_function_constant_defined",
12954 		"CHAR_BIT",
12955 		"SCHAR_MAX",
12956 		"SCHAR_MIN",
12957 		"UCHAR_MAX",
12958 		"CHAR_MAX",
12959 		"CHAR_MIN",
12960 		"USHRT_MAX",
12961 		"SHRT_MAX",
12962 		"SHRT_MIN",
12963 		"UINT_MAX",
12964 		"INT_MAX",
12965 		"INT_MIN",
12966 		"FLT_DIG",
12967 		"FLT_MANT_DIG",
12968 		"FLT_MAX_10_EXP",
12969 		"FLT_MAX_EXP",
12970 		"FLT_MIN_10_EXP",
12971 		"FLT_MIN_EXP",
12972 		"FLT_RADIX",
12973 		"FLT_MAX",
12974 		"FLT_MIN",
12975 		"FLT_EPSILON",
12976 		"FP_ILOGB0",
12977 		"FP_ILOGBNAN",
12978 		"MAXFLOAT",
12979 		"HUGE_VALF",
12980 		"INFINITY",
12981 		"NAN",
12982 		"M_E_F",
12983 		"M_LOG2E_F",
12984 		"M_LOG10E_F",
12985 		"M_LN2_F",
12986 		"M_LN10_F",
12987 		"M_PI_F",
12988 		"M_PI_2_F",
12989 		"M_PI_4_F",
12990 		"M_1_PI_F",
12991 		"M_2_PI_F",
12992 		"M_2_SQRTPI_F",
12993 		"M_SQRT2_F",
12994 		"M_SQRT1_2_F",
12995 		"HALF_DIG",
12996 		"HALF_MANT_DIG",
12997 		"HALF_MAX_10_EXP",
12998 		"HALF_MAX_EXP",
12999 		"HALF_MIN_10_EXP",
13000 		"HALF_MIN_EXP",
13001 		"HALF_RADIX",
13002 		"HALF_MAX",
13003 		"HALF_MIN",
13004 		"HALF_EPSILON",
13005 		"MAXHALF",
13006 		"HUGE_VALH",
13007 		"M_E_H",
13008 		"M_LOG2E_H",
13009 		"M_LOG10E_H",
13010 		"M_LN2_H",
13011 		"M_LN10_H",
13012 		"M_PI_H",
13013 		"M_PI_2_H",
13014 		"M_PI_4_H",
13015 		"M_1_PI_H",
13016 		"M_2_PI_H",
13017 		"M_2_SQRTPI_H",
13018 		"M_SQRT2_H",
13019 		"M_SQRT1_2_H",
13020 		"DBL_DIG",
13021 		"DBL_MANT_DIG",
13022 		"DBL_MAX_10_EXP",
13023 		"DBL_MAX_EXP",
13024 		"DBL_MIN_10_EXP",
13025 		"DBL_MIN_EXP",
13026 		"DBL_RADIX",
13027 		"DBL_MAX",
13028 		"DBL_MIN",
13029 		"DBL_EPSILON",
13030 		"HUGE_VAL",
13031 		"M_E",
13032 		"M_LOG2E",
13033 		"M_LOG10E",
13034 		"M_LN2",
13035 		"M_LN10",
13036 		"M_PI",
13037 		"M_PI_2",
13038 		"M_PI_4",
13039 		"M_1_PI",
13040 		"M_2_PI",
13041 		"M_2_SQRTPI",
13042 		"M_SQRT2",
13043 		"M_SQRT1_2",
13044 	};
13045 
13046 	return illegal_func_names;
13047 }
13048 
13049 // Replace all names that match MSL keywords or Metal Standard Library functions.
replace_illegal_names()13050 void CompilerMSL::replace_illegal_names()
13051 {
13052 	// FIXME: MSL and GLSL are doing two different things here.
13053 	// Agree on convention and remove this override.
13054 	auto &keywords = get_reserved_keyword_set();
13055 	auto &illegal_func_names = get_illegal_func_names();
13056 
13057 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &) {
13058 		auto *meta = ir.find_meta(self);
13059 		if (!meta)
13060 			return;
13061 
13062 		auto &dec = meta->decoration;
13063 		if (keywords.find(dec.alias) != end(keywords))
13064 			dec.alias += "0";
13065 	});
13066 
13067 	ir.for_each_typed_id<SPIRFunction>([&](uint32_t self, SPIRFunction &) {
13068 		auto *meta = ir.find_meta(self);
13069 		if (!meta)
13070 			return;
13071 
13072 		auto &dec = meta->decoration;
13073 		if (illegal_func_names.find(dec.alias) != end(illegal_func_names))
13074 			dec.alias += "0";
13075 	});
13076 
13077 	ir.for_each_typed_id<SPIRType>([&](uint32_t self, SPIRType &) {
13078 		auto *meta = ir.find_meta(self);
13079 		if (!meta)
13080 			return;
13081 
13082 		for (auto &mbr_dec : meta->members)
13083 			if (keywords.find(mbr_dec.alias) != end(keywords))
13084 				mbr_dec.alias += "0";
13085 	});
13086 
13087 	CompilerGLSL::replace_illegal_names();
13088 }
13089 
replace_illegal_entry_point_names()13090 void CompilerMSL::replace_illegal_entry_point_names()
13091 {
13092 	auto &illegal_func_names = get_illegal_func_names();
13093 
13094 	// It is important to this before we fixup identifiers,
13095 	// since if ep_name is reserved, we will need to fix that up,
13096 	// and then copy alias back into entry.name after the fixup.
13097 	for (auto &entry : ir.entry_points)
13098 	{
13099 		// Change both the entry point name and the alias, to keep them synced.
13100 		string &ep_name = entry.second.name;
13101 		if (illegal_func_names.find(ep_name) != end(illegal_func_names))
13102 			ep_name += "0";
13103 
13104 		ir.meta[entry.first].decoration.alias = ep_name;
13105 	}
13106 }
13107 
sync_entry_point_aliases_and_names()13108 void CompilerMSL::sync_entry_point_aliases_and_names()
13109 {
13110 	for (auto &entry : ir.entry_points)
13111 		entry.second.name = ir.meta[entry.first].decoration.alias;
13112 }
13113 
to_member_reference(uint32_t base,const SPIRType & type,uint32_t index,bool ptr_chain)13114 string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain)
13115 {
13116 	if (index < uint32_t(type.member_type_index_redirection.size()))
13117 		index = type.member_type_index_redirection[index];
13118 
13119 	auto *var = maybe_get<SPIRVariable>(base);
13120 	// If this is a buffer array, we have to dereference the buffer pointers.
13121 	// Otherwise, if this is a pointer expression, dereference it.
13122 
13123 	bool declared_as_pointer = false;
13124 
13125 	if (var)
13126 	{
13127 		// Only allow -> dereference for block types. This is so we get expressions like
13128 		// buffer[i]->first_member.second_member, rather than buffer[i]->first->second.
13129 		bool is_block = has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);
13130 
13131 		bool is_buffer_variable =
13132 		    is_block && (var->storage == StorageClassUniform || var->storage == StorageClassStorageBuffer);
13133 		declared_as_pointer = is_buffer_variable && is_array(get<SPIRType>(var->basetype));
13134 	}
13135 
13136 	if (declared_as_pointer || (!ptr_chain && should_dereference(base)))
13137 		return join("->", to_member_name(type, index));
13138 	else
13139 		return join(".", to_member_name(type, index));
13140 }
13141 
to_qualifiers_glsl(uint32_t id)13142 string CompilerMSL::to_qualifiers_glsl(uint32_t id)
13143 {
13144 	string quals;
13145 
13146 	auto *var = maybe_get<SPIRVariable>(id);
13147 	auto &type = expression_type(id);
13148 
13149 	if (type.storage == StorageClassWorkgroup || (var && variable_decl_is_remapped_storage(*var, StorageClassWorkgroup)))
13150 		quals += "threadgroup ";
13151 
13152 	return quals;
13153 }
13154 
13155 // The optional id parameter indicates the object whose type we are trying
13156 // to find the description for. It is optional. Most type descriptions do not
13157 // depend on a specific object's use of that type.
type_to_glsl(const SPIRType & type,uint32_t id)13158 string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id)
13159 {
13160 	string type_name;
13161 
13162 	// Pointer?
13163 	if (type.pointer)
13164 	{
13165 		assert(type.pointer_depth > 0);
13166 
13167 		const char *restrict_kw;
13168 
13169 		auto type_address_space = get_type_address_space(type, id);
13170 		auto type_decl = type_to_glsl(get<SPIRType>(type.parent_type), id);
13171 
13172 		// Work around C pointer qualifier rules. If glsl_type is a pointer type as well
13173 		// we'll need to emit the address space to the right.
13174 		// We could always go this route, but it makes the code unnatural.
13175 		// Prefer emitting thread T *foo over T thread* foo since it's more readable,
13176 		// but we'll have to emit thread T * thread * T constant bar; for example.
13177 		if (type_is_pointer_to_pointer(type))
13178 			type_name = join(type_decl, " ", type_address_space, " ");
13179 		else
13180 			type_name = join(type_address_space, " ", type_decl);
13181 
13182 		switch (type.basetype)
13183 		{
13184 		case SPIRType::Image:
13185 		case SPIRType::SampledImage:
13186 		case SPIRType::Sampler:
13187 			// These are handles.
13188 			break;
13189 		default:
13190 			// Anything else can be a raw pointer.
13191 			type_name += "*";
13192 			restrict_kw = to_restrict(id);
13193 			if (*restrict_kw)
13194 			{
13195 				type_name += " ";
13196 				type_name += restrict_kw;
13197 			}
13198 			break;
13199 		}
13200 		return type_name;
13201 	}
13202 
13203 	switch (type.basetype)
13204 	{
13205 	case SPIRType::Struct:
13206 		// Need OpName lookup here to get a "sensible" name for a struct.
13207 		// Allow Metal to use the array<T> template to make arrays a value type
13208 		type_name = to_name(type.self);
13209 		break;
13210 
13211 	case SPIRType::Image:
13212 	case SPIRType::SampledImage:
13213 		return image_type_glsl(type, id);
13214 
13215 	case SPIRType::Sampler:
13216 		return sampler_type(type, id);
13217 
13218 	case SPIRType::Void:
13219 		return "void";
13220 
13221 	case SPIRType::AtomicCounter:
13222 		return "atomic_uint";
13223 
13224 	case SPIRType::ControlPointArray:
13225 		return join("patch_control_point<", type_to_glsl(get<SPIRType>(type.parent_type), id), ">");
13226 
13227 	case SPIRType::Interpolant:
13228 		return join("interpolant<", type_to_glsl(get<SPIRType>(type.parent_type), id), ", interpolation::",
13229 		            has_decoration(type.self, DecorationNoPerspective) ? "no_perspective" : "perspective", ">");
13230 
13231 	// Scalars
13232 	case SPIRType::Boolean:
13233 		type_name = "bool";
13234 		break;
13235 	case SPIRType::Char:
13236 	case SPIRType::SByte:
13237 		type_name = "char";
13238 		break;
13239 	case SPIRType::UByte:
13240 		type_name = "uchar";
13241 		break;
13242 	case SPIRType::Short:
13243 		type_name = "short";
13244 		break;
13245 	case SPIRType::UShort:
13246 		type_name = "ushort";
13247 		break;
13248 	case SPIRType::Int:
13249 		type_name = "int";
13250 		break;
13251 	case SPIRType::UInt:
13252 		type_name = "uint";
13253 		break;
13254 	case SPIRType::Int64:
13255 		if (!msl_options.supports_msl_version(2, 2))
13256 			SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
13257 		type_name = "long";
13258 		break;
13259 	case SPIRType::UInt64:
13260 		if (!msl_options.supports_msl_version(2, 2))
13261 			SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
13262 		type_name = "ulong";
13263 		break;
13264 	case SPIRType::Half:
13265 		type_name = "half";
13266 		break;
13267 	case SPIRType::Float:
13268 		type_name = "float";
13269 		break;
13270 	case SPIRType::Double:
13271 		type_name = "double"; // Currently unsupported
13272 		break;
13273 
13274 	default:
13275 		return "unknown_type";
13276 	}
13277 
13278 	// Matrix?
13279 	if (type.columns > 1)
13280 		type_name += to_string(type.columns) + "x";
13281 
13282 	// Vector or Matrix?
13283 	if (type.vecsize > 1)
13284 		type_name += to_string(type.vecsize);
13285 
13286 	if (type.array.empty() || using_builtin_array())
13287 	{
13288 		return type_name;
13289 	}
13290 	else
13291 	{
13292 		// Allow Metal to use the array<T> template to make arrays a value type
13293 		add_spv_func_and_recompile(SPVFuncImplUnsafeArray);
13294 		string res;
13295 		string sizes;
13296 
13297 		for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
13298 		{
13299 			res += "spvUnsafeArray<";
13300 			sizes += ", ";
13301 			sizes += to_array_size(type, i);
13302 			sizes += ">";
13303 		}
13304 
13305 		res += type_name + sizes;
13306 		return res;
13307 	}
13308 }
13309 
type_to_array_glsl(const SPIRType & type)13310 string CompilerMSL::type_to_array_glsl(const SPIRType &type)
13311 {
13312 	// Allow Metal to use the array<T> template to make arrays a value type
13313 	switch (type.basetype)
13314 	{
13315 	case SPIRType::AtomicCounter:
13316 	case SPIRType::ControlPointArray:
13317 	{
13318 		return CompilerGLSL::type_to_array_glsl(type);
13319 	}
13320 	default:
13321 	{
13322 		if (using_builtin_array())
13323 			return CompilerGLSL::type_to_array_glsl(type);
13324 		else
13325 			return "";
13326 	}
13327 	}
13328 }
13329 
variable_decl_is_remapped_storage(const SPIRVariable & variable,spv::StorageClass storage) const13330 bool CompilerMSL::variable_decl_is_remapped_storage(const SPIRVariable &variable, spv::StorageClass storage) const
13331 {
13332 	if (variable.storage == storage)
13333 		return true;
13334 
13335 	if (storage == StorageClassWorkgroup)
13336 	{
13337 		auto model = get_execution_model();
13338 
13339 		// Specially masked IO block variable.
13340 		// Normally, we will never access IO blocks directly here.
13341 		// The only scenario which that should occur is with a masked IO block.
13342 		if (model == ExecutionModelTessellationControl && variable.storage == StorageClassOutput &&
13343 		    has_decoration(get<SPIRType>(variable.basetype).self, DecorationBlock))
13344 		{
13345 			return true;
13346 		}
13347 
13348 		return variable.storage == StorageClassOutput &&
13349 		       model == ExecutionModelTessellationControl &&
13350 		       is_stage_output_variable_masked(variable);
13351 	}
13352 	else if (storage == StorageClassStorageBuffer)
13353 	{
13354 		// We won't be able to catch writes to control point outputs here since variable
13355 		// refers to a function local pointer.
13356 		// This is fine, as there cannot be concurrent writers to that memory anyways,
13357 		// so we just ignore that case.
13358 
13359 		return (variable.storage == StorageClassOutput || variable.storage == StorageClassInput) &&
13360 		       !variable_storage_requires_stage_io(variable.storage) &&
13361 		       (variable.storage != StorageClassOutput || !is_stage_output_variable_masked(variable));
13362 	}
13363 	else
13364 	{
13365 		return false;
13366 	}
13367 }
13368 
variable_decl(const SPIRVariable & variable)13369 std::string CompilerMSL::variable_decl(const SPIRVariable &variable)
13370 {
13371 	bool old_is_using_builtin_array = is_using_builtin_array;
13372 
13373 	// Threadgroup arrays can't have a wrapper type.
13374 	if (variable_decl_is_remapped_storage(variable, StorageClassWorkgroup))
13375 		is_using_builtin_array = true;
13376 
13377 	std::string expr = CompilerGLSL::variable_decl(variable);
13378 	is_using_builtin_array = old_is_using_builtin_array;
13379 	return expr;
13380 }
13381 
13382 // GCC workaround of lambdas calling protected funcs
variable_decl(const SPIRType & type,const std::string & name,uint32_t id)13383 std::string CompilerMSL::variable_decl(const SPIRType &type, const std::string &name, uint32_t id)
13384 {
13385 	return CompilerGLSL::variable_decl(type, name, id);
13386 }
13387 
sampler_type(const SPIRType & type,uint32_t id)13388 std::string CompilerMSL::sampler_type(const SPIRType &type, uint32_t id)
13389 {
13390 	auto *var = maybe_get<SPIRVariable>(id);
13391 	if (var && var->basevariable)
13392 	{
13393 		// Check against the base variable, and not a fake ID which might have been generated for this variable.
13394 		id = var->basevariable;
13395 	}
13396 
13397 	if (!type.array.empty())
13398 	{
13399 		if (!msl_options.supports_msl_version(2))
13400 			SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of samplers.");
13401 
13402 		if (type.array.size() > 1)
13403 			SPIRV_CROSS_THROW("Arrays of arrays of samplers are not supported in MSL.");
13404 
13405 		// Arrays of samplers in MSL must be declared with a special array<T, N> syntax ala C++11 std::array.
13406 		// If we have a runtime array, it could be a variable-count descriptor set binding.
13407 		uint32_t array_size = to_array_size_literal(type);
13408 		if (array_size == 0)
13409 			array_size = get_resource_array_size(id);
13410 
13411 		if (array_size == 0)
13412 			SPIRV_CROSS_THROW("Unsized array of samplers is not supported in MSL.");
13413 
13414 		auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
13415 		return join("array<", sampler_type(parent, id), ", ", array_size, ">");
13416 	}
13417 	else
13418 		return "sampler";
13419 }
13420 
13421 // Returns an MSL string describing the SPIR-V image type
image_type_glsl(const SPIRType & type,uint32_t id)13422 string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id)
13423 {
13424 	auto *var = maybe_get<SPIRVariable>(id);
13425 	if (var && var->basevariable)
13426 	{
13427 		// For comparison images, check against the base variable,
13428 		// and not the fake ID which might have been generated for this variable.
13429 		id = var->basevariable;
13430 	}
13431 
13432 	if (!type.array.empty())
13433 	{
13434 		uint32_t major = 2, minor = 0;
13435 		if (msl_options.is_ios())
13436 		{
13437 			major = 1;
13438 			minor = 2;
13439 		}
13440 		if (!msl_options.supports_msl_version(major, minor))
13441 		{
13442 			if (msl_options.is_ios())
13443 				SPIRV_CROSS_THROW("MSL 1.2 or greater is required for arrays of textures.");
13444 			else
13445 				SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of textures.");
13446 		}
13447 
13448 		if (type.array.size() > 1)
13449 			SPIRV_CROSS_THROW("Arrays of arrays of textures are not supported in MSL.");
13450 
13451 		// Arrays of images in MSL must be declared with a special array<T, N> syntax ala C++11 std::array.
13452 		// If we have a runtime array, it could be a variable-count descriptor set binding.
13453 		uint32_t array_size = to_array_size_literal(type);
13454 		if (array_size == 0)
13455 			array_size = get_resource_array_size(id);
13456 
13457 		if (array_size == 0)
13458 			SPIRV_CROSS_THROW("Unsized array of images is not supported in MSL.");
13459 
13460 		auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
13461 		return join("array<", image_type_glsl(parent, id), ", ", array_size, ">");
13462 	}
13463 
13464 	string img_type_name;
13465 
13466 	// Bypass pointers because we need the real image struct
13467 	auto &img_type = get<SPIRType>(type.self).image;
13468 	if (image_is_comparison(type, id))
13469 	{
13470 		switch (img_type.dim)
13471 		{
13472 		case Dim1D:
13473 		case Dim2D:
13474 			if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
13475 			{
13476 				// Use a native Metal 1D texture
13477 				img_type_name += "depth1d_unsupported_by_metal";
13478 				break;
13479 			}
13480 
13481 			if (img_type.ms && img_type.arrayed)
13482 			{
13483 				if (!msl_options.supports_msl_version(2, 1))
13484 					SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
13485 				img_type_name += "depth2d_ms_array";
13486 			}
13487 			else if (img_type.ms)
13488 				img_type_name += "depth2d_ms";
13489 			else if (img_type.arrayed)
13490 				img_type_name += "depth2d_array";
13491 			else
13492 				img_type_name += "depth2d";
13493 			break;
13494 		case Dim3D:
13495 			img_type_name += "depth3d_unsupported_by_metal";
13496 			break;
13497 		case DimCube:
13498 			if (!msl_options.emulate_cube_array)
13499 				img_type_name += (img_type.arrayed ? "depthcube_array" : "depthcube");
13500 			else
13501 				img_type_name += (img_type.arrayed ? "depth2d_array" : "depthcube");
13502 			break;
13503 		default:
13504 			img_type_name += "unknown_depth_texture_type";
13505 			break;
13506 		}
13507 	}
13508 	else
13509 	{
13510 		switch (img_type.dim)
13511 		{
13512 		case DimBuffer:
13513 			if (img_type.ms || img_type.arrayed)
13514 				SPIRV_CROSS_THROW("Cannot use texel buffers with multisampling or array layers.");
13515 
13516 			if (msl_options.texture_buffer_native)
13517 			{
13518 				if (!msl_options.supports_msl_version(2, 1))
13519 					SPIRV_CROSS_THROW("Native texture_buffer type is only supported in MSL 2.1.");
13520 				img_type_name = "texture_buffer";
13521 			}
13522 			else
13523 				img_type_name += "texture2d";
13524 			break;
13525 		case Dim1D:
13526 		case Dim2D:
13527 		case DimSubpassData:
13528 		{
13529 			bool subpass_array =
13530 			    img_type.dim == DimSubpassData && (msl_options.multiview || msl_options.arrayed_subpass_input);
13531 			if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
13532 			{
13533 				// Use a native Metal 1D texture
13534 				img_type_name += (img_type.arrayed ? "texture1d_array" : "texture1d");
13535 				break;
13536 			}
13537 
13538 			// Use Metal's native frame-buffer fetch API for subpass inputs.
13539 			if (type_is_msl_framebuffer_fetch(type))
13540 			{
13541 				auto img_type_4 = get<SPIRType>(img_type.type);
13542 				img_type_4.vecsize = 4;
13543 				return type_to_glsl(img_type_4);
13544 			}
13545 			if (img_type.ms && (img_type.arrayed || subpass_array))
13546 			{
13547 				if (!msl_options.supports_msl_version(2, 1))
13548 					SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
13549 				img_type_name += "texture2d_ms_array";
13550 			}
13551 			else if (img_type.ms)
13552 				img_type_name += "texture2d_ms";
13553 			else if (img_type.arrayed || subpass_array)
13554 				img_type_name += "texture2d_array";
13555 			else
13556 				img_type_name += "texture2d";
13557 			break;
13558 		}
13559 		case Dim3D:
13560 			img_type_name += "texture3d";
13561 			break;
13562 		case DimCube:
13563 			if (!msl_options.emulate_cube_array)
13564 				img_type_name += (img_type.arrayed ? "texturecube_array" : "texturecube");
13565 			else
13566 				img_type_name += (img_type.arrayed ? "texture2d_array" : "texturecube");
13567 			break;
13568 		default:
13569 			img_type_name += "unknown_texture_type";
13570 			break;
13571 		}
13572 	}
13573 
13574 	// Append the pixel type
13575 	img_type_name += "<";
13576 	img_type_name += type_to_glsl(get<SPIRType>(img_type.type));
13577 
13578 	// For unsampled images, append the sample/read/write access qualifier.
13579 	// For kernel images, the access qualifier my be supplied directly by SPIR-V.
13580 	// Otherwise it may be set based on whether the image is read from or written to within the shader.
13581 	if (type.basetype == SPIRType::Image && type.image.sampled == 2 && type.image.dim != DimSubpassData)
13582 	{
13583 		switch (img_type.access)
13584 		{
13585 		case AccessQualifierReadOnly:
13586 			img_type_name += ", access::read";
13587 			break;
13588 
13589 		case AccessQualifierWriteOnly:
13590 			img_type_name += ", access::write";
13591 			break;
13592 
13593 		case AccessQualifierReadWrite:
13594 			img_type_name += ", access::read_write";
13595 			break;
13596 
13597 		default:
13598 		{
13599 			auto *p_var = maybe_get_backing_variable(id);
13600 			if (p_var && p_var->basevariable)
13601 				p_var = maybe_get<SPIRVariable>(p_var->basevariable);
13602 			if (p_var && !has_decoration(p_var->self, DecorationNonWritable))
13603 			{
13604 				img_type_name += ", access::";
13605 
13606 				if (!has_decoration(p_var->self, DecorationNonReadable))
13607 					img_type_name += "read_";
13608 
13609 				img_type_name += "write";
13610 			}
13611 			break;
13612 		}
13613 		}
13614 	}
13615 
13616 	img_type_name += ">";
13617 
13618 	return img_type_name;
13619 }
13620 
// Emits the MSL expression for a SPIR-V subgroup (OpGroupNonUniform*) instruction.
//
// i: the instruction being translated. Its operand words are read as ops[0] = result type id,
//    ops[1] = result id, ops[2] = execution scope id, and ops[3..] = op-specific operands.
//
// Throws (SPIRV_CROSS_THROW) when:
//  - subgroup emulation is enabled and the op is anything other than OpGroupNonUniformElect,
//  - the target MSL version is below 2.0, or below what the op needs on iOS/macOS,
//  - the scope operand does not evaluate to ScopeSubgroup,
//  - the group operation or the opcode is not one of those handled below.
//
// On the non-emulated path the result id is registered as a control-dependent expression
// before returning.
emit_subgroup_op(const Instruction & i)13621 void CompilerMSL::emit_subgroup_op(const Instruction &i)
13622 {
13623 	const uint32_t *ops = stream(i);
13624 	auto op = static_cast<Op>(i.op);
13625 
13626 	if (msl_options.emulate_subgroups)
13627 	{
13628 		// In this mode, only the GroupNonUniform cap is supported. The only op
13629 		// we need to handle, then, is OpGroupNonUniformElect.
13630 		if (op != OpGroupNonUniformElect)
13631 			SPIRV_CROSS_THROW("Subgroup emulation does not support operations other than Elect.");
13632 		// In this mode, the subgroup size is assumed to be one, so every invocation
13633 		// is elected.
13634 		emit_op(ops[0], ops[1], "true", true);
13635 		return;
13636 	}
13637 
13638 	// Metal 2.0 is required. iOS only supports quad ops on 11.0 (2.0), with
13639 	// full support in 13.0 (2.2). macOS only supports broadcast and shuffle on
13640 	// 10.13 (2.0), with full support in 10.14 (2.1).
13641 	// Note that Apple GPUs before A13 make no distinction between a quad-group
13642 	// and a SIMD-group; all SIMD-groups are quad-groups on those.
13643 	if (!msl_options.supports_msl_version(2))
13644 		SPIRV_CROSS_THROW("Subgroups are only supported in Metal 2.0 and up.");
13645 
13646 	// If we need to do implicit bitcasts, make sure we do it with the correct type.
13647 	uint32_t integer_width = get_integer_width_for_instruction(i);
13648 	auto int_type = to_signed_basetype(integer_width);
13649 	auto uint_type = to_unsigned_basetype(integer_width);
13650 
	// iOS gate: applies unless both MSL 2.3 is available and ios_use_simdgroup_functions is set.
	// Ops not named in the switch hit `default` and throw; the ballot-family ops, BroadcastFirst
	// and Elect additionally need MSL 2.2; the broadcast/shuffle/quad ops pass unconditionally.
13651 	if (msl_options.is_ios() && (!msl_options.supports_msl_version(2, 3) || !msl_options.ios_use_simdgroup_functions))
13652 	{
13653 		switch (op)
13654 		{
13655 		default:
13656 			SPIRV_CROSS_THROW("Subgroup ops beyond broadcast, ballot, and shuffle on iOS require Metal 2.3 and up.");
13657 		case OpGroupNonUniformBroadcastFirst:
13658 			if (!msl_options.supports_msl_version(2, 2))
13659 				SPIRV_CROSS_THROW("BroadcastFirst on iOS requires Metal 2.2 and up.");
13660 			break;
13661 		case OpGroupNonUniformElect:
13662 			if (!msl_options.supports_msl_version(2, 2))
13663 				SPIRV_CROSS_THROW("Elect on iOS requires Metal 2.2 and up.");
13664 			break;
13665 		case OpGroupNonUniformAny:
13666 		case OpGroupNonUniformAll:
13667 		case OpGroupNonUniformAllEqual:
13668 		case OpGroupNonUniformBallot:
13669 		case OpGroupNonUniformInverseBallot:
13670 		case OpGroupNonUniformBallotBitExtract:
13671 		case OpGroupNonUniformBallotFindLSB:
13672 		case OpGroupNonUniformBallotFindMSB:
13673 		case OpGroupNonUniformBallotBitCount:
13674 			if (!msl_options.supports_msl_version(2, 2))
13675 				SPIRV_CROSS_THROW("Ballot ops on iOS requires Metal 2.2 and up.");
13676 			break;
13677 		case OpGroupNonUniformBroadcast:
13678 		case OpGroupNonUniformShuffle:
13679 		case OpGroupNonUniformShuffleXor:
13680 		case OpGroupNonUniformShuffleUp:
13681 		case OpGroupNonUniformShuffleDown:
13682 		case OpGroupNonUniformQuadSwap:
13683 		case OpGroupNonUniformQuadBroadcast:
13684 			break;
13685 		}
13686 	}
13687 
	// macOS gate: below MSL 2.1 only broadcast and the shuffle variants are accepted;
	// every other op hits `default` and throws.
13688 	if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
13689 	{
13690 		switch (op)
13691 		{
13692 		default:
13693 			SPIRV_CROSS_THROW("Subgroup ops beyond broadcast and shuffle on macOS require Metal 2.1 and up.");
13694 		case OpGroupNonUniformBroadcast:
13695 		case OpGroupNonUniformShuffle:
13696 		case OpGroupNonUniformShuffleXor:
13697 		case OpGroupNonUniformShuffleUp:
13698 		case OpGroupNonUniformShuffleDown:
13699 			break;
13700 		}
13701 	}
13702 
13703 	uint32_t result_type = ops[0];
13704 	uint32_t id = ops[1];
13705 
	// ops[2] is an id whose constant value is the execution scope; only ScopeSubgroup is accepted.
13706 	auto scope = static_cast<Scope>(evaluate_constant_u32(ops[2]));
13707 	if (scope != ScopeSubgroup)
13708 		SPIRV_CROSS_THROW("Only subgroup scope is supported.");
13709 
	// Per-op emission. Names starting with "spv" are emitted as calls to helper functions
	// (presumably generated elsewhere in the output — not visible in this function).
13710 	switch (op)
13711 	{
13712 	case OpGroupNonUniformElect:
		// quad_* is used on iOS unless ios_use_simdgroup_functions is set; simd_* otherwise.
13713 		if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
13714 			emit_op(result_type, id, "quad_is_first()", false);
13715 		else
13716 			emit_op(result_type, id, "simd_is_first()", false);
13717 		break;
13718 
13719 	case OpGroupNonUniformBroadcast:
13720 		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBroadcast");
13721 		break;
13722 
13723 	case OpGroupNonUniformBroadcastFirst:
13724 		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBroadcastFirst");
13725 		break;
13726 
13727 	case OpGroupNonUniformBallot:
13728 		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallot");
13729 		break;
13730 
13731 	case OpGroupNonUniformInverseBallot:
		// Emitted as a bit-extract of the ballot (ops[3]) at builtin_subgroup_invocation_id_id.
13732 		emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_invocation_id_id, "spvSubgroupBallotBitExtract");
13733 		break;
13734 
13735 	case OpGroupNonUniformBallotBitExtract:
13736 		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBallotBitExtract");
13737 		break;
13738 
	// FindLSB / FindMSB pass builtin_subgroup_size_id as the second helper argument.
13739 	case OpGroupNonUniformBallotFindLSB:
13740 		emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_size_id, "spvSubgroupBallotFindLSB");
13741 		break;
13742 
13743 	case OpGroupNonUniformBallotFindMSB:
13744 		emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_size_id, "spvSubgroupBallotFindMSB");
13745 		break;
13746 
13747 	case OpGroupNonUniformBallotBitCount:
13748 	{
		// Here ops[3] is the GroupOperation and ops[4] the ballot value. Reduce passes
		// builtin_subgroup_size_id as the second argument; the two scans pass
		// builtin_subgroup_invocation_id_id instead.
13749 		auto operation = static_cast<GroupOperation>(ops[3]);
13750 		switch (operation)
13751 		{
13752 		case GroupOperationReduce:
13753 			emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_size_id, "spvSubgroupBallotBitCount");
13754 			break;
13755 		case GroupOperationInclusiveScan:
13756 			emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
13757 			                    "spvSubgroupBallotInclusiveBitCount");
13758 			break;
13759 		case GroupOperationExclusiveScan:
13760 			emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
13761 			                    "spvSubgroupBallotExclusiveBitCount");
13762 			break;
13763 		default:
13764 			SPIRV_CROSS_THROW("Invalid BitCount operation.");
13765 		}
13766 		break;
13767 	}
13768 
13769 	case OpGroupNonUniformShuffle:
13770 		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffle");
13771 		break;
13772 
13773 	case OpGroupNonUniformShuffleXor:
13774 		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleXor");
13775 		break;
13776 
13777 	case OpGroupNonUniformShuffleUp:
13778 		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleUp");
13779 		break;
13780 
13781 	case OpGroupNonUniformShuffleDown:
13782 		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleDown");
13783 		break;
13784 
13785 	case OpGroupNonUniformAll:
13786 		if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
13787 			emit_unary_func_op(result_type, id, ops[3], "quad_all");
13788 		else
13789 			emit_unary_func_op(result_type, id, ops[3], "simd_all");
13790 		break;
13791 
13792 	case OpGroupNonUniformAny:
13793 		if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
13794 			emit_unary_func_op(result_type, id, ops[3], "quad_any");
13795 		else
13796 			emit_unary_func_op(result_type, id, ops[3], "simd_any");
13797 		break;
13798 
13799 	case OpGroupNonUniformAllEqual:
13800 		emit_unary_func_op(result_type, id, ops[3], "spvSubgroupAllEqual");
13801 		break;
13802 
13803 		// clang-format off
	// First MSL_GROUP_OP variant: expands to one `case OpGroupNonUniform<op>` that reads the
	// GroupOperation from ops[3] and the value from ops[4], then emits simd_<msl_op> (Reduce),
	// simd_prefix_inclusive_<msl_op> / simd_prefix_exclusive_<msl_op> (scans), or
	// quad_<msl_op> (ClusteredReduce, whose cluster size in ops[5] must evaluate to 4).
13804 #define MSL_GROUP_OP(op, msl_op) \
13805 case OpGroupNonUniform##op: \
13806 	{ \
13807 		auto operation = static_cast<GroupOperation>(ops[3]); \
13808 		if (operation == GroupOperationReduce) \
13809 			emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
13810 		else if (operation == GroupOperationInclusiveScan) \
13811 			emit_unary_func_op(result_type, id, ops[4], "simd_prefix_inclusive_" #msl_op); \
13812 		else if (operation == GroupOperationExclusiveScan) \
13813 			emit_unary_func_op(result_type, id, ops[4], "simd_prefix_exclusive_" #msl_op); \
13814 		else if (operation == GroupOperationClusteredReduce) \
13815 		{ \
13816 			/* Only cluster sizes of 4 are supported. */ \
13817 			uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
13818 			if (cluster_size != 4) \
13819 				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
13820 			emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
13821 		} \
13822 		else \
13823 			SPIRV_CROSS_THROW("Invalid group operation."); \
13824 		break; \
13825 	}
13826 	MSL_GROUP_OP(FAdd, sum)
13827 	MSL_GROUP_OP(FMul, product)
13828 	MSL_GROUP_OP(IAdd, sum)
13829 	MSL_GROUP_OP(IMul, product)
13830 #undef MSL_GROUP_OP
13831 	// The others, unfortunately, don't support InclusiveScan or ExclusiveScan.
13832 
	// Second MSL_GROUP_OP variant (redefined after the #undef above): same Reduce and
	// ClusteredReduce handling, but InclusiveScan and ExclusiveScan throw.
13833 #define MSL_GROUP_OP(op, msl_op) \
13834 case OpGroupNonUniform##op: \
13835 	{ \
13836 		auto operation = static_cast<GroupOperation>(ops[3]); \
13837 		if (operation == GroupOperationReduce) \
13838 			emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
13839 		else if (operation == GroupOperationInclusiveScan) \
13840 			SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
13841 		else if (operation == GroupOperationExclusiveScan) \
13842 			SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
13843 		else if (operation == GroupOperationClusteredReduce) \
13844 		{ \
13845 			/* Only cluster sizes of 4 are supported. */ \
13846 			uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
13847 			if (cluster_size != 4) \
13848 				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
13849 			emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
13850 		} \
13851 		else \
13852 			SPIRV_CROSS_THROW("Invalid group operation."); \
13853 		break; \
13854 	}
13855 
	// MSL_GROUP_OP_CAST: like the second MSL_GROUP_OP variant, but emits through
	// emit_unary_func_op_cast, passing `type` for both of its trailing basetype arguments.
	// It is instantiated below with int_type / uint_type, which were derived from the
	// instruction's integer width near the top of this function.
13856 #define MSL_GROUP_OP_CAST(op, msl_op, type) \
13857 case OpGroupNonUniform##op: \
13858 	{ \
13859 		auto operation = static_cast<GroupOperation>(ops[3]); \
13860 		if (operation == GroupOperationReduce) \
13861 			emit_unary_func_op_cast(result_type, id, ops[4], "simd_" #msl_op, type, type); \
13862 		else if (operation == GroupOperationInclusiveScan) \
13863 			SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
13864 		else if (operation == GroupOperationExclusiveScan) \
13865 			SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
13866 		else if (operation == GroupOperationClusteredReduce) \
13867 		{ \
13868 			/* Only cluster sizes of 4 are supported. */ \
13869 			uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
13870 			if (cluster_size != 4) \
13871 				SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
13872 			emit_unary_func_op_cast(result_type, id, ops[4], "quad_" #msl_op, type, type); \
13873 		} \
13874 		else \
13875 			SPIRV_CROSS_THROW("Invalid group operation."); \
13876 		break; \
13877 	}
13878 
13879 	MSL_GROUP_OP(FMin, min)
13880 	MSL_GROUP_OP(FMax, max)
13881 	MSL_GROUP_OP_CAST(SMin, min, int_type)
13882 	MSL_GROUP_OP_CAST(SMax, max, int_type)
13883 	MSL_GROUP_OP_CAST(UMin, min, uint_type)
13884 	MSL_GROUP_OP_CAST(UMax, max, uint_type)
13885 	MSL_GROUP_OP(BitwiseAnd, and)
13886 	MSL_GROUP_OP(BitwiseOr, or)
13887 	MSL_GROUP_OP(BitwiseXor, xor)
13888 	MSL_GROUP_OP(LogicalAnd, and)
13889 	MSL_GROUP_OP(LogicalOr, or)
13890 	MSL_GROUP_OP(LogicalXor, xor)
13891 		// clang-format on
13892 #undef MSL_GROUP_OP
13893 #undef MSL_GROUP_OP_CAST
13894 
13895 	case OpGroupNonUniformQuadSwap:
13896 		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvQuadSwap");
13897 		break;
13898 
13899 	case OpGroupNonUniformQuadBroadcast:
13900 		emit_binary_func_op(result_type, id, ops[3], ops[4], "spvQuadBroadcast");
13901 		break;
13902 
13903 	default:
13904 		SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
13905 	}
13906 
	// Every successfully emitted subgroup result is marked control-dependent.
13907 	register_control_dependent_expression(id);
13908 }
13909 
bitcast_glsl_op(const SPIRType & out_type,const SPIRType & in_type)13910 string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
13911 {
13912 	if (out_type.basetype == in_type.basetype)
13913 		return "";
13914 
13915 	assert(out_type.basetype != SPIRType::Boolean);
13916 	assert(in_type.basetype != SPIRType::Boolean);
13917 
13918 	bool integral_cast = type_is_integral(out_type) && type_is_integral(in_type) && (out_type.vecsize == in_type.vecsize);
13919 	bool same_size_cast = (out_type.width * out_type.vecsize) == (in_type.width * in_type.vecsize);
13920 
13921 	// Bitcasting can only be used between types of the same overall size.
13922 	// And always formally cast between integers, because it's trivial, and also
13923 	// because Metal can internally cast the results of some integer ops to a larger
13924 	// size (eg. short shift right becomes int), which means chaining integer ops
13925 	// together may introduce size variations that SPIR-V doesn't know about.
13926 	if (same_size_cast && !integral_cast)
13927 	{
13928 		return "as_type<" + type_to_glsl(out_type) + ">";
13929 	}
13930 	else
13931 	{
13932 		return type_to_glsl(out_type);
13933 	}
13934 }
13935 
emit_complex_bitcast(uint32_t,uint32_t,uint32_t)13936 bool CompilerMSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
13937 {
13938 	return false;
13939 }
13940 
13941 // Returns an MSL string identifying the name of a SPIR-V builtin.
13942 // Output builtins are qualified with the name of the stage out structure.
// Maps a SPIR-V builtin to the MSL expression string used to reference it.
//
// builtin: the builtin being referenced.
// storage: storage class of the reference; several cases only apply when it is (or is not)
//          StorageClassInput.
//
// Any case that leaves the switch without returning or throwing falls back to
// CompilerGLSL::builtin_to_glsl(builtin, storage).
// Side effects visible here: ensure_builtin() calls, and updates to needs_base_vertex_arg /
// needs_base_instance_arg.
// Throws (SPIRV_CROSS_THROW) for DrawIndex, for BaseVertex / BaseInstance when their
// version/platform requirement is not met, and for ViewportIndex below MSL 2.0.
builtin_to_glsl(BuiltIn builtin,StorageClass storage)13943 string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
13944 {
13945 	switch (builtin)
13946 	{
13947 	// Handle HLSL-style 0-based vertex/instance index.
13948 	// Override GLSL compiler strictness
	// The four vertex/instance cases below share one shape. When enable_base_index_zero is set,
	// MSL 1.1 is available, and base vertex/instance is supported (ios_support_base_vertex_instance
	// or macOS):
	//  - while builtin_declaration is set, the plain name is returned and the matching
	//    needs_base_*_arg is raised to Yes unless it is already No;
	//  - otherwise the base builtin is ensured and "(index - base)" is returned.
	// In all other configurations the plain name is returned.
13949 	case BuiltInVertexId:
13950 		ensure_builtin(StorageClassInput, BuiltInVertexId);
13951 		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
13952 		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13953 		{
13954 			if (builtin_declaration)
13955 			{
13956 				if (needs_base_vertex_arg != TriState::No)
13957 					needs_base_vertex_arg = TriState::Yes;
13958 				return "gl_VertexID";
13959 			}
13960 			else
13961 			{
13962 				ensure_builtin(StorageClassInput, BuiltInBaseVertex);
13963 				return "(gl_VertexID - gl_BaseVertex)";
13964 			}
13965 		}
13966 		else
13967 		{
13968 			return "gl_VertexID";
13969 		}
13970 	case BuiltInInstanceId:
13971 		ensure_builtin(StorageClassInput, BuiltInInstanceId);
13972 		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
13973 		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13974 		{
13975 			if (builtin_declaration)
13976 			{
13977 				if (needs_base_instance_arg != TriState::No)
13978 					needs_base_instance_arg = TriState::Yes;
13979 				return "gl_InstanceID";
13980 			}
13981 			else
13982 			{
13983 				ensure_builtin(StorageClassInput, BuiltInBaseInstance);
13984 				return "(gl_InstanceID - gl_BaseInstance)";
13985 			}
13986 		}
13987 		else
13988 		{
13989 			return "gl_InstanceID";
13990 		}
13991 	case BuiltInVertexIndex:
13992 		ensure_builtin(StorageClassInput, BuiltInVertexIndex);
13993 		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
13994 		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13995 		{
13996 			if (builtin_declaration)
13997 			{
13998 				if (needs_base_vertex_arg != TriState::No)
13999 					needs_base_vertex_arg = TriState::Yes;
14000 				return "gl_VertexIndex";
14001 			}
14002 			else
14003 			{
14004 				ensure_builtin(StorageClassInput, BuiltInBaseVertex);
14005 				return "(gl_VertexIndex - gl_BaseVertex)";
14006 			}
14007 		}
14008 		else
14009 		{
14010 			return "gl_VertexIndex";
14011 		}
14012 	case BuiltInInstanceIndex:
14013 		ensure_builtin(StorageClassInput, BuiltInInstanceIndex);
14014 		if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
14015 		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
14016 		{
14017 			if (builtin_declaration)
14018 			{
14019 				if (needs_base_instance_arg != TriState::No)
14020 					needs_base_instance_arg = TriState::Yes;
14021 				return "gl_InstanceIndex";
14022 			}
14023 			else
14024 			{
14025 				ensure_builtin(StorageClassInput, BuiltInBaseInstance);
14026 				return "(gl_InstanceIndex - gl_BaseInstance)";
14027 			}
14028 		}
14029 		else
14030 		{
14031 			return "gl_InstanceIndex";
14032 		}
	// Referencing the base builtins directly sets needs_base_*_arg to No, which the
	// vertex/instance cases above then leave untouched.
14033 	case BuiltInBaseVertex:
14034 		if (msl_options.supports_msl_version(1, 1) &&
14035 		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
14036 		{
14037 			needs_base_vertex_arg = TriState::No;
14038 			return "gl_BaseVertex";
14039 		}
14040 		else
14041 		{
14042 			SPIRV_CROSS_THROW("BaseVertex requires Metal 1.1 and Mac or Apple A9+ hardware.");
14043 		}
14044 	case BuiltInBaseInstance:
14045 		if (msl_options.supports_msl_version(1, 1) &&
14046 		    (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
14047 		{
14048 			needs_base_instance_arg = TriState::No;
14049 			return "gl_BaseInstance";
14050 		}
14051 		else
14052 		{
14053 			SPIRV_CROSS_THROW("BaseInstance requires Metal 1.1 and Mac or Apple A9+ hardware.");
14054 		}
14055 	case BuiltInDrawIndex:
14056 		SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
14057 
14058 	// When used in the entry function, output builtins are qualified with output struct name.
14059 	// Test storage class as NOT Input, as output builtins might be part of generic type.
14060 	// Also don't do this for tessellation control shaders.
	// The next three groups form one deliberate fallthrough chain:
	// ViewportIndex (version check) -> FragDepth / FragStencilRefEXT (enable check) ->
	// Position / PointSize / ClipDistance / CullDistance / Layer (qualification).
14061 	case BuiltInViewportIndex:
14062 		if (!msl_options.supports_msl_version(2, 0))
14063 			SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
14064 		/* fallthrough */
14065 	case BuiltInFragDepth:
14066 	case BuiltInFragStencilRefEXT:
		// FragDepth / FragStencilRefEXT skip the qualification when their enable_* option is off.
14067 		if ((builtin == BuiltInFragDepth && !msl_options.enable_frag_depth_builtin) ||
14068 		    (builtin == BuiltInFragStencilRefEXT && !msl_options.enable_frag_stencil_ref_builtin))
14069 			break;
14070 		/* fallthrough */
14071 	case BuiltInPosition:
14072 	case BuiltInPointSize:
14073 	case BuiltInClipDistance:
14074 	case BuiltInCullDistance:
14075 	case BuiltInLayer:
14076 		if (get_execution_model() == ExecutionModelTessellationControl)
14077 			break;
		// Qualify with the stage-out variable only for non-input references made inside the
		// default entry point, and only when the builtin is not masked.
14078 		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point) &&
14079 		    !is_stage_output_builtin_masked(builtin))
14080 			return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
14081 		break;
14082 
14083 	case BuiltInSampleMask:
		// Input sample mask read in the default entry point: AND the base expression with the
		// additional fixed sample mask and/or the current sample's bit, whichever apply.
14084 		if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point) &&
14085 			(has_additional_fixed_sample_mask() || needs_sample_id))
14086 		{
14087 			string samp_mask_in;
14088 			samp_mask_in += "(" + CompilerGLSL::builtin_to_glsl(builtin, storage);
14089 			if (has_additional_fixed_sample_mask())
14090 				samp_mask_in += " & " + additional_fixed_sample_mask_str();
14091 			if (needs_sample_id)
14092 				samp_mask_in += " & (1 << gl_SampleID)";
14093 			samp_mask_in += ")";
14094 			return samp_mask_in;
14095 		}
		// Non-input sample mask: same stage-out qualification rule as the other output builtins.
14096 		if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point) &&
14097 		    !is_stage_output_builtin_masked(builtin))
14098 			return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
14099 		break;
14100 
14101 	case BuiltInBaryCoordNV:
14102 	case BuiltInBaryCoordNoPerspNV:
		// Barycentric inputs read in the default entry point are qualified with the stage-in variable.
14103 		if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
14104 			return stage_in_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
14105 		break;
14106 
	// Tessellation levels written (non-input) in a tessellation-control default entry point
	// are redirected to the tess factor buffer, indexed by the expression for
	// builtin_primitive_id_id.
14107 	case BuiltInTessLevelOuter:
14108 		if (get_execution_model() == ExecutionModelTessellationControl &&
14109 		    storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
14110 		{
14111 			return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
14112 			            "].edgeTessellationFactor");
14113 		}
14114 		break;
14115 
14116 	case BuiltInTessLevelInner:
14117 		if (get_execution_model() == ExecutionModelTessellationControl &&
14118 		    storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
14119 		{
14120 			return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
14121 			            "].insideTessellationFactor");
14122 		}
14123 		break;
14124 
14125 	default:
14126 		break;
14127 	}
14128 
	// No MSL-specific spelling applied: defer to the base-class name.
14129 	return CompilerGLSL::builtin_to_glsl(builtin, storage);
14130 }
14131 
14132 // Returns an MSL string attribute qualifier for a SPIR-V builtin
// Maps a SPIR-V builtin to the text of its MSL attribute qualifier.
//
// builtin: the builtin to look up.
// Returns the attribute text (e.g. "vertex_id"); builtins with no case below yield the
// placeholder "unsupported-built-in".
// Throws (SPIRV_CROSS_THROW) for builtins that have no MSL attribute here, for builtins whose
// messages below state they are handled by other means, and when an MSL version / platform /
// execution-model requirement is not met.
builtin_qualifier(BuiltIn builtin)14133 string CompilerMSL::builtin_qualifier(BuiltIn builtin)
14134 {
14135 	auto &execution = get_entry_point();
14136 
14137 	switch (builtin)
14138 	{
14139 	// Vertex function in
	// VertexId / VertexIndex share "vertex_id"; InstanceId / InstanceIndex share "instance_id".
14140 	case BuiltInVertexId:
14141 		return "vertex_id";
14142 	case BuiltInVertexIndex:
14143 		return "vertex_id";
14144 	case BuiltInBaseVertex:
14145 		return "base_vertex";
14146 	case BuiltInInstanceId:
14147 		return "instance_id";
14148 	case BuiltInInstanceIndex:
14149 		return "instance_id";
14150 	case BuiltInBaseInstance:
14151 		return "base_instance";
14152 	case BuiltInDrawIndex:
14153 		SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
14154 
14155 	// Vertex function out
14156 	case BuiltInClipDistance:
14157 		return "clip_distance";
14158 	case BuiltInPointSize:
14159 		return "point_size";
14160 	case BuiltInPosition:
		// With position_invariant set, the attribute carries an extra ", invariant" (MSL 2.1+).
14161 		if (position_invariant)
14162 		{
14163 			if (!msl_options.supports_msl_version(2, 1))
14164 				SPIRV_CROSS_THROW("Invariant position is only supported on MSL 2.1 and up.");
14165 			return "position, invariant";
14166 		}
14167 		else
14168 			return "position";
14169 	case BuiltInLayer:
14170 		return "render_target_array_index";
14171 	case BuiltInViewportIndex:
14172 		if (!msl_options.supports_msl_version(2, 0))
14173 			SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
14174 		return "viewport_array_index";
14175 
14176 	// Tess. control function in
14177 	case BuiltInInvocationId:
14178 		if (msl_options.multi_patch_workgroup)
14179 		{
14180 			// Shouldn't be reached.
14181 			SPIRV_CROSS_THROW("InvocationId is computed manually with multi-patch workgroups in MSL.");
14182 		}
14183 		return "thread_index_in_threadgroup";
14184 	case BuiltInPatchVertices:
14185 		// Shouldn't be reached.
14186 		SPIRV_CROSS_THROW("PatchVertices is derived from the auxiliary buffer in MSL.");
14187 	case BuiltInPrimitiveId:
		// The attribute depends on the entry point's execution model; every path of this
		// inner switch returns or throws, so control never reaches the next outer case.
14188 		switch (execution.model)
14189 		{
14190 		case ExecutionModelTessellationControl:
14191 			if (msl_options.multi_patch_workgroup)
14192 			{
14193 				// Shouldn't be reached.
14194 				SPIRV_CROSS_THROW("PrimitiveId is computed manually with multi-patch workgroups in MSL.");
14195 			}
14196 			return "threadgroup_position_in_grid";
14197 		case ExecutionModelTessellationEvaluation:
14198 			return "patch_id";
14199 		case ExecutionModelFragment:
14200 			if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
14201 				SPIRV_CROSS_THROW("PrimitiveId on iOS requires MSL 2.3.");
14202 			else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 2))
14203 				SPIRV_CROSS_THROW("PrimitiveId on macOS requires MSL 2.2.");
14204 			return "primitive_id";
14205 		default:
14206 			SPIRV_CROSS_THROW("PrimitiveId is not supported in this execution model.");
14207 		}
14208 
14209 	// Tess. control function out
14210 	case BuiltInTessLevelOuter:
14211 	case BuiltInTessLevelInner:
14212 		// Shouldn't be reached.
14213 		SPIRV_CROSS_THROW("Tessellation levels are handled specially in MSL.");
14214 
14215 	// Tess. evaluation function in
14216 	case BuiltInTessCoord:
14217 		return "position_in_patch";
14218 
14219 	// Fragment function in
14220 	case BuiltInFrontFacing:
14221 		return "front_facing";
14222 	case BuiltInPointCoord:
14223 		return "point_coord";
14224 	case BuiltInFragCoord:
14225 		return "position";
14226 	case BuiltInSampleId:
14227 		return "sample_id";
14228 	case BuiltInSampleMask:
14229 		return "sample_mask";
14230 	case BuiltInSamplePosition:
14231 		// Shouldn't be reached.
14232 		SPIRV_CROSS_THROW("Sample position is retrieved by a function in MSL.");
14233 	case BuiltInViewIndex:
14234 		if (execution.model != ExecutionModelFragment)
14235 			SPIRV_CROSS_THROW("ViewIndex is handled specially outside fragment shaders.");
14236 		// The ViewIndex was implicitly used in the prior stages to set the render_target_array_index,
14237 		// so we can get it from there.
14238 		return "render_target_array_index";
14239 
14240 	// Fragment function out
14241 	case BuiltInFragDepth:
		// The depth argument follows the entry point's execution mode flags; DepthGreater is
		// tested before DepthLess, and "any" is used when neither is set.
14242 		if (execution.flags.get(ExecutionModeDepthGreater))
14243 			return "depth(greater)";
14244 		else if (execution.flags.get(ExecutionModeDepthLess))
14245 			return "depth(less)";
14246 		else
14247 			return "depth(any)";
14248 
14249 	case BuiltInFragStencilRefEXT:
14250 		return "stencil";
14251 
14252 	// Compute function in
14253 	case BuiltInGlobalInvocationId:
14254 		return "thread_position_in_grid";
14255 
14256 	case BuiltInWorkgroupId:
14257 		return "threadgroup_position_in_grid";
14258 
14259 	case BuiltInNumWorkgroups:
14260 		return "threadgroups_per_grid";
14261 
14262 	case BuiltInLocalInvocationId:
14263 		return "thread_position_in_threadgroup";
14264 
14265 	case BuiltInLocalInvocationIndex:
14266 		return "thread_index_in_threadgroup";
14267 
14268 	case BuiltInSubgroupSize:
14269 		if (msl_options.emulate_subgroups || msl_options.fixed_subgroup_size != 0)
14270 			// Shouldn't be reached.
14271 			SPIRV_CROSS_THROW("Emitting threads_per_simdgroup attribute with fixed subgroup size??");
14272 		if (execution.model == ExecutionModelFragment)
14273 		{
14274 			if (!msl_options.supports_msl_version(2, 2))
14275 				SPIRV_CROSS_THROW("threads_per_simdgroup requires Metal 2.2 in fragment shaders.");
14276 			return "threads_per_simdgroup";
14277 		}
14278 		else
14279 		{
14280 			// thread_execution_width is an alias for threads_per_simdgroup that has been available
14281 			// since MSL 1.0, but not in fragment functions.
14282 			return "thread_execution_width";
14283 		}
14284 
	// NOTE(review): NumSubgroups, SubgroupId and (in kernel-style functions)
	// SubgroupLocalInvocationId pick the quadgroup attribute for every iOS target, whereas
	// emit_subgroup_op also consults msl_options.ios_use_simdgroup_functions when choosing
	// between quad_* and simd_* functions — confirm this difference is intended.
14285 	case BuiltInNumSubgroups:
14286 		if (msl_options.emulate_subgroups)
14287 			// Shouldn't be reached.
14288 			SPIRV_CROSS_THROW("NumSubgroups is handled specially with emulation.");
14289 		if (!msl_options.supports_msl_version(2))
14290 			SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
14291 		return msl_options.is_ios() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
14292 
14293 	case BuiltInSubgroupId:
14294 		if (msl_options.emulate_subgroups)
14295 			// Shouldn't be reached.
14296 			SPIRV_CROSS_THROW("SubgroupId is handled specially with emulation.");
14297 		if (!msl_options.supports_msl_version(2))
14298 			SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
14299 		return msl_options.is_ios() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
14300 
14301 	case BuiltInSubgroupLocalInvocationId:
14302 		if (msl_options.emulate_subgroups)
14303 			// Shouldn't be reached.
14304 			SPIRV_CROSS_THROW("SubgroupLocalInvocationId is handled specially with emulation.");
		// Fragment functions always use the simdgroup attribute (MSL 2.2+). The entry points
		// listed in the second branch use the quadgroup attribute on iOS; any other
		// execution model throws.
14305 		if (execution.model == ExecutionModelFragment)
14306 		{
14307 			if (!msl_options.supports_msl_version(2, 2))
14308 				SPIRV_CROSS_THROW("thread_index_in_simdgroup requires Metal 2.2 in fragment shaders.");
14309 			return "thread_index_in_simdgroup";
14310 		}
14311 		else if (execution.model == ExecutionModelKernel || execution.model == ExecutionModelGLCompute ||
14312 		         execution.model == ExecutionModelTessellationControl ||
14313 		         (execution.model == ExecutionModelVertex && msl_options.vertex_for_tessellation))
14314 		{
14315 			// We are generating a Metal kernel function.
14316 			if (!msl_options.supports_msl_version(2))
14317 				SPIRV_CROSS_THROW("Subgroup builtins in kernel functions require Metal 2.0.");
14318 			return msl_options.is_ios() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
14319 		}
14320 		else
14321 			SPIRV_CROSS_THROW("Subgroup builtins are not available in this type of function.");
14322 
14323 	case BuiltInSubgroupEqMask:
14324 	case BuiltInSubgroupGeMask:
14325 	case BuiltInSubgroupGtMask:
14326 	case BuiltInSubgroupLeMask:
14327 	case BuiltInSubgroupLtMask:
14328 		// Shouldn't be reached.
14329 		SPIRV_CROSS_THROW("Subgroup ballot masks are handled specially in MSL.");
14330 
	// Both barycentric builtins use the same version gate (MSL 2.3 on iOS, otherwise MSL 2.2)
	// and differ only in the perspective argument of the returned attribute.
14331 	case BuiltInBaryCoordNV:
14332 		// TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
14333 		if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
14334 			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.3 and above on iOS.");
14335 		else if (!msl_options.supports_msl_version(2, 2))
14336 			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
14337 		return "barycentric_coord, center_perspective";
14338 
14339 	case BuiltInBaryCoordNoPerspNV:
14340 		// TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
14341 		if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
14342 			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.3 and above on iOS.");
14343 		else if (!msl_options.supports_msl_version(2, 2))
14344 			SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
14345 		return "barycentric_coord, center_no_perspective";
14346 
14347 	default:
		// Unknown builtins do not throw; they produce a placeholder attribute string.
14348 		return "unsupported-built-in";
14349 	}
14350 }
14351 
14352 // Returns an MSL string type declaration for a SPIR-V builtin
builtin_type_decl(BuiltIn builtin,uint32_t id)14353 string CompilerMSL::builtin_type_decl(BuiltIn builtin, uint32_t id)
14354 {
14355 	const SPIREntryPoint &execution = get_entry_point();
14356 	switch (builtin)
14357 	{
14358 	// Vertex function in
14359 	case BuiltInVertexId:
14360 		return "uint";
14361 	case BuiltInVertexIndex:
14362 		return "uint";
14363 	case BuiltInBaseVertex:
14364 		return "uint";
14365 	case BuiltInInstanceId:
14366 		return "uint";
14367 	case BuiltInInstanceIndex:
14368 		return "uint";
14369 	case BuiltInBaseInstance:
14370 		return "uint";
14371 	case BuiltInDrawIndex:
14372 		SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
14373 
14374 	// Vertex function out
14375 	case BuiltInClipDistance:
14376 	case BuiltInCullDistance:
14377 		return "float";
14378 	case BuiltInPointSize:
14379 		return "float";
14380 	case BuiltInPosition:
14381 		return "float4";
14382 	case BuiltInLayer:
14383 		return "uint";
14384 	case BuiltInViewportIndex:
14385 		if (!msl_options.supports_msl_version(2, 0))
14386 			SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
14387 		return "uint";
14388 
14389 	// Tess. control function in
14390 	case BuiltInInvocationId:
14391 		return "uint";
14392 	case BuiltInPatchVertices:
14393 		return "uint";
14394 	case BuiltInPrimitiveId:
14395 		return "uint";
14396 
14397 	// Tess. control function out
14398 	case BuiltInTessLevelInner:
14399 		if (execution.model == ExecutionModelTessellationEvaluation)
14400 			return !execution.flags.get(ExecutionModeTriangles) ? "float2" : "float";
14401 		return "half";
14402 	case BuiltInTessLevelOuter:
14403 		if (execution.model == ExecutionModelTessellationEvaluation)
14404 			return !execution.flags.get(ExecutionModeTriangles) ? "float4" : "float";
14405 		return "half";
14406 
14407 	// Tess. evaluation function in
14408 	case BuiltInTessCoord:
14409 		return execution.flags.get(ExecutionModeTriangles) ? "float3" : "float2";
14410 
14411 	// Fragment function in
14412 	case BuiltInFrontFacing:
14413 		return "bool";
14414 	case BuiltInPointCoord:
14415 		return "float2";
14416 	case BuiltInFragCoord:
14417 		return "float4";
14418 	case BuiltInSampleId:
14419 		return "uint";
14420 	case BuiltInSampleMask:
14421 		return "uint";
14422 	case BuiltInSamplePosition:
14423 		return "float2";
14424 	case BuiltInViewIndex:
14425 		return "uint";
14426 
14427 	case BuiltInHelperInvocation:
14428 		return "bool";
14429 
14430 	case BuiltInBaryCoordNV:
14431 	case BuiltInBaryCoordNoPerspNV:
14432 		// Use the type as declared, can be 1, 2 or 3 components.
14433 		return type_to_glsl(get_variable_data_type(get<SPIRVariable>(id)));
14434 
14435 	// Fragment function out
14436 	case BuiltInFragDepth:
14437 		return "float";
14438 
14439 	case BuiltInFragStencilRefEXT:
14440 		return "uint";
14441 
14442 	// Compute function in
14443 	case BuiltInGlobalInvocationId:
14444 	case BuiltInLocalInvocationId:
14445 	case BuiltInNumWorkgroups:
14446 	case BuiltInWorkgroupId:
14447 		return "uint3";
14448 	case BuiltInLocalInvocationIndex:
14449 	case BuiltInNumSubgroups:
14450 	case BuiltInSubgroupId:
14451 	case BuiltInSubgroupSize:
14452 	case BuiltInSubgroupLocalInvocationId:
14453 		return "uint";
14454 	case BuiltInSubgroupEqMask:
14455 	case BuiltInSubgroupGeMask:
14456 	case BuiltInSubgroupGtMask:
14457 	case BuiltInSubgroupLeMask:
14458 	case BuiltInSubgroupLtMask:
14459 		return "uint4";
14460 
14461 	case BuiltInDeviceIndex:
14462 		return "int";
14463 
14464 	default:
14465 		return "unsupported-built-in-type";
14466 	}
14467 }
14468 
14469 // Returns the declaration of a built-in argument to a function
built_in_func_arg(BuiltIn builtin,bool prefix_comma)14470 string CompilerMSL::built_in_func_arg(BuiltIn builtin, bool prefix_comma)
14471 {
14472 	string bi_arg;
14473 	if (prefix_comma)
14474 		bi_arg += ", ";
14475 
14476 	// Handle HLSL-style 0-based vertex/instance index.
14477 	builtin_declaration = true;
14478 	bi_arg += builtin_type_decl(builtin);
14479 	bi_arg += " " + builtin_to_glsl(builtin, StorageClassInput);
14480 	bi_arg += " [[" + builtin_qualifier(builtin) + "]]";
14481 	builtin_declaration = false;
14482 
14483 	return bi_arg;
14484 }
14485 
// Returns the type used for the physical layout of struct member 'index': the remapped
// physical type when the member has one, otherwise the member's declared type.
const SPIRType &CompilerMSL::get_physical_member_type(const SPIRType &type, uint32_t index) const
{
	if (!member_is_remapped_physical_type(type, index))
		return get<SPIRType>(type.member_types[index]);

	// The remapped type ID is stored as an extended member decoration.
	const uint32_t physical_type_id =
	    get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID);
	return get<SPIRType>(physical_type_id);
}
14493 
// Returns a copy of the physical type of input-block member 'index'. If an entry is registered
// in inputs_by_location for the member's Location and it has more components than the member,
// the copy's vecsize is widened to that entry's vecsize.
SPIRType CompilerMSL::get_presumed_input_type(const SPIRType &ib_type, uint32_t index) const
{
	SPIRType presumed = get_physical_member_type(ib_type, index);
	const uint32_t location = get_member_decoration(ib_type.self, index, DecorationLocation);

	auto input_itr = inputs_by_location.find(location);
	if (input_itr != inputs_by_location.end())
	{
		// Only ever widen; a narrower registered input leaves the member's vecsize alone.
		if (input_itr->second.vecsize > presumed.vecsize)
			presumed.vecsize = input_itr->second.vecsize;
	}
	return presumed;
}
14505 
get_declared_type_array_stride_msl(const SPIRType & type,bool is_packed,bool row_major) const14506 uint32_t CompilerMSL::get_declared_type_array_stride_msl(const SPIRType &type, bool is_packed, bool row_major) const
14507 {
14508 	// Array stride in MSL is always size * array_size. sizeof(float3) == 16,
14509 	// unlike GLSL and HLSL where array stride would be 16 and size 12.
14510 
14511 	// We could use parent type here and recurse, but that makes creating physical type remappings
14512 	// far more complicated. We'd rather just create the final type, and ignore having to create the entire type
14513 	// hierarchy in order to compute this value, so make a temporary type on the stack.
14514 
14515 	auto basic_type = type;
14516 	basic_type.array.clear();
14517 	basic_type.array_size_literal.clear();
14518 	uint32_t value_size = get_declared_type_size_msl(basic_type, is_packed, row_major);
14519 
14520 	uint32_t dimensions = uint32_t(type.array.size());
14521 	assert(dimensions > 0);
14522 	dimensions--;
14523 
14524 	// Multiply together every dimension, except the last one.
14525 	for (uint32_t dim = 0; dim < dimensions; dim++)
14526 	{
14527 		uint32_t array_size = to_array_size_literal(type, dim);
14528 		value_size *= max(array_size, 1u);
14529 	}
14530 
14531 	return value_size;
14532 }
14533 
get_declared_struct_member_array_stride_msl(const SPIRType & type,uint32_t index) const14534 uint32_t CompilerMSL::get_declared_struct_member_array_stride_msl(const SPIRType &type, uint32_t index) const
14535 {
14536 	return get_declared_type_array_stride_msl(get_physical_member_type(type, index),
14537 	                                          member_is_packed_physical_type(type, index),
14538 	                                          has_member_decoration(type.self, index, DecorationRowMajor));
14539 }
14540 
get_declared_input_array_stride_msl(const SPIRType & type,uint32_t index) const14541 uint32_t CompilerMSL::get_declared_input_array_stride_msl(const SPIRType &type, uint32_t index) const
14542 {
14543 	return get_declared_type_array_stride_msl(get_presumed_input_type(type, index), false,
14544 	                                          has_member_decoration(type.self, index, DecorationRowMajor));
14545 }
14546 
get_declared_type_matrix_stride_msl(const SPIRType & type,bool packed,bool row_major) const14547 uint32_t CompilerMSL::get_declared_type_matrix_stride_msl(const SPIRType &type, bool packed, bool row_major) const
14548 {
14549 	// For packed matrices, we just use the size of the vector type.
14550 	// Otherwise, MatrixStride == alignment, which is the size of the underlying vector type.
14551 	if (packed)
14552 		return (type.width / 8) * ((row_major && type.columns > 1) ? type.columns : type.vecsize);
14553 	else
14554 		return get_declared_type_alignment_msl(type, false, row_major);
14555 }
14556 
get_declared_struct_member_matrix_stride_msl(const SPIRType & type,uint32_t index) const14557 uint32_t CompilerMSL::get_declared_struct_member_matrix_stride_msl(const SPIRType &type, uint32_t index) const
14558 {
14559 	return get_declared_type_matrix_stride_msl(get_physical_member_type(type, index),
14560 	                                           member_is_packed_physical_type(type, index),
14561 	                                           has_member_decoration(type.self, index, DecorationRowMajor));
14562 }
14563 
get_declared_input_matrix_stride_msl(const SPIRType & type,uint32_t index) const14564 uint32_t CompilerMSL::get_declared_input_matrix_stride_msl(const SPIRType &type, uint32_t index) const
14565 {
14566 	return get_declared_type_matrix_stride_msl(get_presumed_input_type(type, index), false,
14567 	                                           has_member_decoration(type.self, index, DecorationRowMajor));
14568 }
14569 
get_declared_struct_size_msl(const SPIRType & struct_type,bool ignore_alignment,bool ignore_padding) const14570 uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type, bool ignore_alignment,
14571                                                    bool ignore_padding) const
14572 {
14573 	// If we have a target size, that is the declared size as well.
14574 	if (!ignore_padding && has_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget))
14575 		return get_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget);
14576 
14577 	if (struct_type.member_types.empty())
14578 		return 0;
14579 
14580 	uint32_t mbr_cnt = uint32_t(struct_type.member_types.size());
14581 
14582 	// In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
14583 	uint32_t alignment = 1;
14584 
14585 	if (!ignore_alignment)
14586 	{
14587 		for (uint32_t i = 0; i < mbr_cnt; i++)
14588 		{
14589 			uint32_t mbr_alignment = get_declared_struct_member_alignment_msl(struct_type, i);
14590 			alignment = max(alignment, mbr_alignment);
14591 		}
14592 	}
14593 
14594 	// Last member will always be matched to the final Offset decoration, but size of struct in MSL now depends
14595 	// on physical size in MSL, and the size of the struct itself is then aligned to struct alignment.
14596 	uint32_t spirv_offset = type_struct_member_offset(struct_type, mbr_cnt - 1);
14597 	uint32_t msl_size = spirv_offset + get_declared_struct_member_size_msl(struct_type, mbr_cnt - 1);
14598 	msl_size = (msl_size + alignment - 1) & ~(alignment - 1);
14599 	return msl_size;
14600 }
14601 
14602 // Returns the byte size of a struct member.
get_declared_type_size_msl(const SPIRType & type,bool is_packed,bool row_major) const14603 uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const
14604 {
14605 	switch (type.basetype)
14606 	{
14607 	case SPIRType::Unknown:
14608 	case SPIRType::Void:
14609 	case SPIRType::AtomicCounter:
14610 	case SPIRType::Image:
14611 	case SPIRType::SampledImage:
14612 	case SPIRType::Sampler:
14613 		SPIRV_CROSS_THROW("Querying size of opaque object.");
14614 
14615 	default:
14616 	{
14617 		if (!type.array.empty())
14618 		{
14619 			uint32_t array_size = to_array_size_literal(type);
14620 			return get_declared_type_array_stride_msl(type, is_packed, row_major) * max(array_size, 1u);
14621 		}
14622 
14623 		if (type.basetype == SPIRType::Struct)
14624 			return get_declared_struct_size_msl(type);
14625 
14626 		if (is_packed)
14627 		{
14628 			return type.vecsize * type.columns * (type.width / 8);
14629 		}
14630 		else
14631 		{
14632 			// An unpacked 3-element vector or matrix column is the same memory size as a 4-element.
14633 			uint32_t vecsize = type.vecsize;
14634 			uint32_t columns = type.columns;
14635 
14636 			if (row_major && columns > 1)
14637 				swap(vecsize, columns);
14638 
14639 			if (vecsize == 3)
14640 				vecsize = 4;
14641 
14642 			return vecsize * columns * (type.width / 8);
14643 		}
14644 	}
14645 	}
14646 }
14647 
get_declared_struct_member_size_msl(const SPIRType & type,uint32_t index) const14648 uint32_t CompilerMSL::get_declared_struct_member_size_msl(const SPIRType &type, uint32_t index) const
14649 {
14650 	return get_declared_type_size_msl(get_physical_member_type(type, index),
14651 	                                  member_is_packed_physical_type(type, index),
14652 	                                  has_member_decoration(type.self, index, DecorationRowMajor));
14653 }
14654 
get_declared_input_size_msl(const SPIRType & type,uint32_t index) const14655 uint32_t CompilerMSL::get_declared_input_size_msl(const SPIRType &type, uint32_t index) const
14656 {
14657 	return get_declared_type_size_msl(get_presumed_input_type(type, index), false,
14658 	                                  has_member_decoration(type.self, index, DecorationRowMajor));
14659 }
14660 
14661 // Returns the byte alignment of a type.
get_declared_type_alignment_msl(const SPIRType & type,bool is_packed,bool row_major) const14662 uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed, bool row_major) const
14663 {
14664 	switch (type.basetype)
14665 	{
14666 	case SPIRType::Unknown:
14667 	case SPIRType::Void:
14668 	case SPIRType::AtomicCounter:
14669 	case SPIRType::Image:
14670 	case SPIRType::SampledImage:
14671 	case SPIRType::Sampler:
14672 		SPIRV_CROSS_THROW("Querying alignment of opaque object.");
14673 
14674 	case SPIRType::Double:
14675 		SPIRV_CROSS_THROW("double types are not supported in buffers in MSL.");
14676 
14677 	case SPIRType::Struct:
14678 	{
14679 		// In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
14680 		uint32_t alignment = 1;
14681 		for (uint32_t i = 0; i < type.member_types.size(); i++)
14682 			alignment = max(alignment, uint32_t(get_declared_struct_member_alignment_msl(type, i)));
14683 		return alignment;
14684 	}
14685 
14686 	default:
14687 	{
14688 		if (type.basetype == SPIRType::Int64 && !msl_options.supports_msl_version(2, 3))
14689 			SPIRV_CROSS_THROW("long types in buffers are only supported in MSL 2.3 and above.");
14690 		if (type.basetype == SPIRType::UInt64 && !msl_options.supports_msl_version(2, 3))
14691 			SPIRV_CROSS_THROW("ulong types in buffers are only supported in MSL 2.3 and above.");
14692 		// Alignment of packed type is the same as the underlying component or column size.
14693 		// Alignment of unpacked type is the same as the vector size.
14694 		// Alignment of 3-elements vector is the same as 4-elements (including packed using column).
14695 		if (is_packed)
14696 		{
14697 			// If we have packed_T and friends, the alignment is always scalar.
14698 			return type.width / 8;
14699 		}
14700 		else
14701 		{
14702 			// This is the general rule for MSL. Size == alignment.
14703 			uint32_t vecsize = (row_major && type.columns > 1) ? type.columns : type.vecsize;
14704 			return (type.width / 8) * (vecsize == 3 ? 4 : vecsize);
14705 		}
14706 	}
14707 	}
14708 }
14709 
get_declared_struct_member_alignment_msl(const SPIRType & type,uint32_t index) const14710 uint32_t CompilerMSL::get_declared_struct_member_alignment_msl(const SPIRType &type, uint32_t index) const
14711 {
14712 	return get_declared_type_alignment_msl(get_physical_member_type(type, index),
14713 	                                       member_is_packed_physical_type(type, index),
14714 	                                       has_member_decoration(type.self, index, DecorationRowMajor));
14715 }
14716 
get_declared_input_alignment_msl(const SPIRType & type,uint32_t index) const14717 uint32_t CompilerMSL::get_declared_input_alignment_msl(const SPIRType &type, uint32_t index) const
14718 {
14719 	return get_declared_type_alignment_msl(get_presumed_input_type(type, index), false,
14720 	                                       has_member_decoration(type.self, index, DecorationRowMajor));
14721 }
14722 
skip_argument(uint32_t) const14723 bool CompilerMSL::skip_argument(uint32_t) const
14724 {
14725 	return false;
14726 }
14727 
analyze_sampled_image_usage()14728 void CompilerMSL::analyze_sampled_image_usage()
14729 {
14730 	if (msl_options.swizzle_texture_samples)
14731 	{
14732 		SampledImageScanner scanner(*this);
14733 		traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), scanner);
14734 	}
14735 }
14736 
handle(spv::Op opcode,const uint32_t * args,uint32_t length)14737 bool CompilerMSL::SampledImageScanner::handle(spv::Op opcode, const uint32_t *args, uint32_t length)
14738 {
14739 	switch (opcode)
14740 	{
14741 	case OpLoad:
14742 	case OpImage:
14743 	case OpSampledImage:
14744 	{
14745 		if (length < 3)
14746 			return false;
14747 
14748 		uint32_t result_type = args[0];
14749 		auto &type = compiler.get<SPIRType>(result_type);
14750 		if ((type.basetype != SPIRType::Image && type.basetype != SPIRType::SampledImage) || type.image.sampled != 1)
14751 			return true;
14752 
14753 		uint32_t id = args[1];
14754 		compiler.set<SPIRExpression>(id, "", result_type, true);
14755 		break;
14756 	}
14757 	case OpImageSampleExplicitLod:
14758 	case OpImageSampleProjExplicitLod:
14759 	case OpImageSampleDrefExplicitLod:
14760 	case OpImageSampleProjDrefExplicitLod:
14761 	case OpImageSampleImplicitLod:
14762 	case OpImageSampleProjImplicitLod:
14763 	case OpImageSampleDrefImplicitLod:
14764 	case OpImageSampleProjDrefImplicitLod:
14765 	case OpImageFetch:
14766 	case OpImageGather:
14767 	case OpImageDrefGather:
14768 		compiler.has_sampled_images =
14769 		    compiler.has_sampled_images || compiler.is_sampled_image_type(compiler.expression_type(args[2]));
14770 		compiler.needs_swizzle_buffer_def = compiler.needs_swizzle_buffer_def || compiler.has_sampled_images;
14771 		break;
14772 	default:
14773 		break;
14774 	}
14775 	return true;
14776 }
14777 
14778 // If a needed custom function wasn't added before, add it and force a recompile.
add_spv_func_and_recompile(SPVFuncImpl spv_func)14779 void CompilerMSL::add_spv_func_and_recompile(SPVFuncImpl spv_func)
14780 {
14781 	if (spv_function_implementations.count(spv_func) == 0)
14782 	{
14783 		spv_function_implementations.insert(spv_func);
14784 		suppress_missing_prototypes = true;
14785 		force_recompile();
14786 	}
14787 }
14788 
handle(Op opcode,const uint32_t * args,uint32_t length)14789 bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, uint32_t length)
14790 {
14791 	// Since MSL exists in a single execution scope, function prototype declarations are not
14792 	// needed, and clutter the output. If secondary functions are output (either as a SPIR-V
14793 	// function implementation or as indicated by the presence of OpFunctionCall), then set
14794 	// suppress_missing_prototypes to suppress compiler warnings of missing function prototypes.
14795 
14796 	// Mark if the input requires the implementation of an SPIR-V function that does not exist in Metal.
14797 	SPVFuncImpl spv_func = get_spv_func_impl(opcode, args);
14798 	if (spv_func != SPVFuncImplNone)
14799 	{
14800 		compiler.spv_function_implementations.insert(spv_func);
14801 		suppress_missing_prototypes = true;
14802 	}
14803 
14804 	switch (opcode)
14805 	{
14806 
14807 	case OpFunctionCall:
14808 		suppress_missing_prototypes = true;
14809 		break;
14810 
14811 	// Emulate texture2D atomic operations
14812 	case OpImageTexelPointer:
14813 	{
14814 		auto *var = compiler.maybe_get_backing_variable(args[2]);
14815 		image_pointers[args[1]] = var ? var->self : ID(0);
14816 		break;
14817 	}
14818 
14819 	case OpImageWrite:
14820 		if (!compiler.msl_options.supports_msl_version(2, 2))
14821 			uses_resource_write = true;
14822 		break;
14823 
14824 	case OpStore:
14825 		check_resource_write(args[0]);
14826 		break;
14827 
14828 	// Emulate texture2D atomic operations
14829 	case OpAtomicExchange:
14830 	case OpAtomicCompareExchange:
14831 	case OpAtomicCompareExchangeWeak:
14832 	case OpAtomicIIncrement:
14833 	case OpAtomicIDecrement:
14834 	case OpAtomicIAdd:
14835 	case OpAtomicISub:
14836 	case OpAtomicSMin:
14837 	case OpAtomicUMin:
14838 	case OpAtomicSMax:
14839 	case OpAtomicUMax:
14840 	case OpAtomicAnd:
14841 	case OpAtomicOr:
14842 	case OpAtomicXor:
14843 	{
14844 		uses_atomics = true;
14845 		auto it = image_pointers.find(args[2]);
14846 		if (it != image_pointers.end())
14847 		{
14848 			compiler.atomic_image_vars.insert(it->second);
14849 		}
14850 		check_resource_write(args[2]);
14851 		break;
14852 	}
14853 
14854 	case OpAtomicStore:
14855 	{
14856 		uses_atomics = true;
14857 		auto it = image_pointers.find(args[0]);
14858 		if (it != image_pointers.end())
14859 		{
14860 			compiler.atomic_image_vars.insert(it->second);
14861 		}
14862 		check_resource_write(args[0]);
14863 		break;
14864 	}
14865 
14866 	case OpAtomicLoad:
14867 	{
14868 		uses_atomics = true;
14869 		auto it = image_pointers.find(args[2]);
14870 		if (it != image_pointers.end())
14871 		{
14872 			compiler.atomic_image_vars.insert(it->second);
14873 		}
14874 		break;
14875 	}
14876 
14877 	case OpGroupNonUniformInverseBallot:
14878 		needs_subgroup_invocation_id = true;
14879 		break;
14880 
14881 	case OpGroupNonUniformBallotFindLSB:
14882 	case OpGroupNonUniformBallotFindMSB:
14883 		needs_subgroup_size = true;
14884 		break;
14885 
14886 	case OpGroupNonUniformBallotBitCount:
14887 		if (args[3] == GroupOperationReduce)
14888 			needs_subgroup_size = true;
14889 		else
14890 			needs_subgroup_invocation_id = true;
14891 		break;
14892 
14893 	case OpArrayLength:
14894 	{
14895 		auto *var = compiler.maybe_get_backing_variable(args[2]);
14896 		if (var)
14897 			compiler.buffers_requiring_array_length.insert(var->self);
14898 		break;
14899 	}
14900 
14901 	case OpInBoundsAccessChain:
14902 	case OpAccessChain:
14903 	case OpPtrAccessChain:
14904 	{
14905 		// OpArrayLength might want to know if taking ArrayLength of an array of SSBOs.
14906 		uint32_t result_type = args[0];
14907 		uint32_t id = args[1];
14908 		uint32_t ptr = args[2];
14909 
14910 		compiler.set<SPIRExpression>(id, "", result_type, true);
14911 		compiler.register_read(id, ptr, true);
14912 		compiler.ir.ids[id].set_allow_type_rewrite();
14913 		break;
14914 	}
14915 
14916 	case OpExtInst:
14917 	{
14918 		uint32_t extension_set = args[2];
14919 		if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
14920 		{
14921 			auto op_450 = static_cast<GLSLstd450>(args[3]);
14922 			switch (op_450)
14923 			{
14924 			case GLSLstd450InterpolateAtCentroid:
14925 			case GLSLstd450InterpolateAtSample:
14926 			case GLSLstd450InterpolateAtOffset:
14927 			{
14928 				if (!compiler.msl_options.supports_msl_version(2, 3))
14929 					SPIRV_CROSS_THROW("Pull-model interpolation requires MSL 2.3.");
14930 				// Fragment varyings used with pull-model interpolation need special handling,
14931 				// due to the way pull-model interpolation works in Metal.
14932 				auto *var = compiler.maybe_get_backing_variable(args[4]);
14933 				if (var)
14934 				{
14935 					compiler.pull_model_inputs.insert(var->self);
14936 					auto &var_type = compiler.get_variable_element_type(*var);
14937 					// In addition, if this variable has a 'Sample' decoration, we need the sample ID
14938 					// in order to do default interpolation.
14939 					if (compiler.has_decoration(var->self, DecorationSample))
14940 					{
14941 						needs_sample_id = true;
14942 					}
14943 					else if (var_type.basetype == SPIRType::Struct)
14944 					{
14945 						// Now we need to check each member and see if it has this decoration.
14946 						for (uint32_t i = 0; i < var_type.member_types.size(); ++i)
14947 						{
14948 							if (compiler.has_member_decoration(var_type.self, i, DecorationSample))
14949 							{
14950 								needs_sample_id = true;
14951 								break;
14952 							}
14953 						}
14954 					}
14955 				}
14956 				break;
14957 			}
14958 			default:
14959 				break;
14960 			}
14961 		}
14962 		break;
14963 	}
14964 
14965 	default:
14966 		break;
14967 	}
14968 
14969 	// If it has one, keep track of the instruction's result type, mapped by ID
14970 	uint32_t result_type, result_id;
14971 	if (compiler.instruction_to_result_type(result_type, result_id, opcode, args, length))
14972 		result_types[result_id] = result_type;
14973 
14974 	return true;
14975 }
14976 
14977 // If the variable is a Uniform or StorageBuffer, mark that a resource has been written to.
check_resource_write(uint32_t var_id)14978 void CompilerMSL::OpCodePreprocessor::check_resource_write(uint32_t var_id)
14979 {
14980 	auto *p_var = compiler.maybe_get_backing_variable(var_id);
14981 	StorageClass sc = p_var ? p_var->storage : StorageClassMax;
14982 	if (!compiler.msl_options.supports_msl_version(2, 1) &&
14983 	    (sc == StorageClassUniform || sc == StorageClassStorageBuffer))
14984 		uses_resource_write = true;
14985 }
14986 
14987 // Returns an enumeration of a SPIR-V function that needs to be output for certain Op codes.
get_spv_func_impl(Op opcode,const uint32_t * args)14988 CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op opcode, const uint32_t *args)
14989 {
14990 	switch (opcode)
14991 	{
14992 	case OpFMod:
14993 		return SPVFuncImplMod;
14994 
14995 	case OpFAdd:
14996 	case OpFSub:
14997 		if (compiler.msl_options.invariant_float_math ||
14998 		    compiler.has_decoration(args[1], DecorationNoContraction))
14999 		{
15000 			return opcode == OpFAdd ? SPVFuncImplFAdd : SPVFuncImplFSub;
15001 		}
15002 		break;
15003 
15004 	case OpFMul:
15005 	case OpOuterProduct:
15006 	case OpMatrixTimesVector:
15007 	case OpVectorTimesMatrix:
15008 	case OpMatrixTimesMatrix:
15009 		if (compiler.msl_options.invariant_float_math ||
15010 		    compiler.has_decoration(args[1], DecorationNoContraction))
15011 		{
15012 			return SPVFuncImplFMul;
15013 		}
15014 		break;
15015 
15016 	case OpTypeArray:
15017 	{
15018 		// Allow Metal to use the array<T> template to make arrays a value type
15019 		return SPVFuncImplUnsafeArray;
15020 	}
15021 
15022 	// Emulate texture2D atomic operations
15023 	case OpAtomicExchange:
15024 	case OpAtomicCompareExchange:
15025 	case OpAtomicCompareExchangeWeak:
15026 	case OpAtomicIIncrement:
15027 	case OpAtomicIDecrement:
15028 	case OpAtomicIAdd:
15029 	case OpAtomicISub:
15030 	case OpAtomicSMin:
15031 	case OpAtomicUMin:
15032 	case OpAtomicSMax:
15033 	case OpAtomicUMax:
15034 	case OpAtomicAnd:
15035 	case OpAtomicOr:
15036 	case OpAtomicXor:
15037 	case OpAtomicLoad:
15038 	case OpAtomicStore:
15039 	{
15040 		auto it = image_pointers.find(args[opcode == OpAtomicStore ? 0 : 2]);
15041 		if (it != image_pointers.end())
15042 		{
15043 			uint32_t tid = compiler.get<SPIRVariable>(it->second).basetype;
15044 			if (tid && compiler.get<SPIRType>(tid).image.dim == Dim2D)
15045 				return SPVFuncImplImage2DAtomicCoords;
15046 		}
15047 		break;
15048 	}
15049 
15050 	case OpImageFetch:
15051 	case OpImageRead:
15052 	case OpImageWrite:
15053 	{
15054 		// Retrieve the image type, and if it's a Buffer, emit a texel coordinate function
15055 		uint32_t tid = result_types[args[opcode == OpImageWrite ? 0 : 2]];
15056 		if (tid && compiler.get<SPIRType>(tid).image.dim == DimBuffer && !compiler.msl_options.texture_buffer_native)
15057 			return SPVFuncImplTexelBufferCoords;
15058 		break;
15059 	}
15060 
15061 	case OpExtInst:
15062 	{
15063 		uint32_t extension_set = args[2];
15064 		if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
15065 		{
15066 			auto op_450 = static_cast<GLSLstd450>(args[3]);
15067 			switch (op_450)
15068 			{
15069 			case GLSLstd450Radians:
15070 				return SPVFuncImplRadians;
15071 			case GLSLstd450Degrees:
15072 				return SPVFuncImplDegrees;
15073 			case GLSLstd450FindILsb:
15074 				return SPVFuncImplFindILsb;
15075 			case GLSLstd450FindSMsb:
15076 				return SPVFuncImplFindSMsb;
15077 			case GLSLstd450FindUMsb:
15078 				return SPVFuncImplFindUMsb;
15079 			case GLSLstd450SSign:
15080 				return SPVFuncImplSSign;
15081 			case GLSLstd450Reflect:
15082 			{
15083 				auto &type = compiler.get<SPIRType>(args[0]);
15084 				if (type.vecsize == 1)
15085 					return SPVFuncImplReflectScalar;
15086 				break;
15087 			}
15088 			case GLSLstd450Refract:
15089 			{
15090 				auto &type = compiler.get<SPIRType>(args[0]);
15091 				if (type.vecsize == 1)
15092 					return SPVFuncImplRefractScalar;
15093 				break;
15094 			}
15095 			case GLSLstd450FaceForward:
15096 			{
15097 				auto &type = compiler.get<SPIRType>(args[0]);
15098 				if (type.vecsize == 1)
15099 					return SPVFuncImplFaceForwardScalar;
15100 				break;
15101 			}
15102 			case GLSLstd450MatrixInverse:
15103 			{
15104 				auto &mat_type = compiler.get<SPIRType>(args[0]);
15105 				switch (mat_type.columns)
15106 				{
15107 				case 2:
15108 					return SPVFuncImplInverse2x2;
15109 				case 3:
15110 					return SPVFuncImplInverse3x3;
15111 				case 4:
15112 					return SPVFuncImplInverse4x4;
15113 				default:
15114 					break;
15115 				}
15116 				break;
15117 			}
15118 			default:
15119 				break;
15120 			}
15121 		}
15122 		break;
15123 	}
15124 
15125 	case OpGroupNonUniformBroadcast:
15126 		return SPVFuncImplSubgroupBroadcast;
15127 
15128 	case OpGroupNonUniformBroadcastFirst:
15129 		return SPVFuncImplSubgroupBroadcastFirst;
15130 
15131 	case OpGroupNonUniformBallot:
15132 		return SPVFuncImplSubgroupBallot;
15133 
15134 	case OpGroupNonUniformInverseBallot:
15135 	case OpGroupNonUniformBallotBitExtract:
15136 		return SPVFuncImplSubgroupBallotBitExtract;
15137 
15138 	case OpGroupNonUniformBallotFindLSB:
15139 		return SPVFuncImplSubgroupBallotFindLSB;
15140 
15141 	case OpGroupNonUniformBallotFindMSB:
15142 		return SPVFuncImplSubgroupBallotFindMSB;
15143 
15144 	case OpGroupNonUniformBallotBitCount:
15145 		return SPVFuncImplSubgroupBallotBitCount;
15146 
15147 	case OpGroupNonUniformAllEqual:
15148 		return SPVFuncImplSubgroupAllEqual;
15149 
15150 	case OpGroupNonUniformShuffle:
15151 		return SPVFuncImplSubgroupShuffle;
15152 
15153 	case OpGroupNonUniformShuffleXor:
15154 		return SPVFuncImplSubgroupShuffleXor;
15155 
15156 	case OpGroupNonUniformShuffleUp:
15157 		return SPVFuncImplSubgroupShuffleUp;
15158 
15159 	case OpGroupNonUniformShuffleDown:
15160 		return SPVFuncImplSubgroupShuffleDown;
15161 
15162 	case OpGroupNonUniformQuadBroadcast:
15163 		return SPVFuncImplQuadBroadcast;
15164 
15165 	case OpGroupNonUniformQuadSwap:
15166 		return SPVFuncImplQuadSwap;
15167 
15168 	default:
15169 		break;
15170 	}
15171 	return SPVFuncImplNone;
15172 }
15173 
15174 // Sort both type and meta member content based on builtin status (put builtins at end),
15175 // then by the required sorting aspect.
sort()15176 void CompilerMSL::MemberSorter::sort()
15177 {
15178 	// Create a temporary array of consecutive member indices and sort it based on how
15179 	// the members should be reordered, based on builtin and sorting aspect meta info.
15180 	size_t mbr_cnt = type.member_types.size();
15181 	SmallVector<uint32_t> mbr_idxs(mbr_cnt);
15182 	std::iota(mbr_idxs.begin(), mbr_idxs.end(), 0); // Fill with consecutive indices
15183 	std::stable_sort(mbr_idxs.begin(), mbr_idxs.end(), *this); // Sort member indices based on sorting aspect
15184 
15185 	bool sort_is_identity = true;
15186 	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
15187 	{
15188 		if (mbr_idx != mbr_idxs[mbr_idx])
15189 		{
15190 			sort_is_identity = false;
15191 			break;
15192 		}
15193 	}
15194 
15195 	if (sort_is_identity)
15196 		return;
15197 
15198 	if (meta.members.size() < type.member_types.size())
15199 	{
15200 		// This should never trigger in normal circumstances, but to be safe.
15201 		meta.members.resize(type.member_types.size());
15202 	}
15203 
15204 	// Move type and meta member info to the order defined by the sorted member indices.
15205 	// This is done by creating temporary copies of both member types and meta, and then
15206 	// copying back to the original content at the sorted indices.
15207 	auto mbr_types_cpy = type.member_types;
15208 	auto mbr_meta_cpy = meta.members;
15209 	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
15210 	{
15211 		type.member_types[mbr_idx] = mbr_types_cpy[mbr_idxs[mbr_idx]];
15212 		meta.members[mbr_idx] = mbr_meta_cpy[mbr_idxs[mbr_idx]];
15213 	}
15214 
15215 	if (sort_aspect == SortAspect::Offset)
15216 	{
15217 		// If we're sorting by Offset, this might affect user code which accesses a buffer block.
15218 		// We will need to redirect member indices from one index to sorted index.
15219 		type.member_type_index_redirection = std::move(mbr_idxs);
15220 	}
15221 }
15222 
operator ()(uint32_t mbr_idx1,uint32_t mbr_idx2)15223 bool CompilerMSL::MemberSorter::operator()(uint32_t mbr_idx1, uint32_t mbr_idx2)
15224 {
15225 	auto &mbr_meta1 = meta.members[mbr_idx1];
15226 	auto &mbr_meta2 = meta.members[mbr_idx2];
15227 
15228 	if (sort_aspect == LocationThenBuiltInType)
15229 	{
15230 		// Sort first by builtin status (put builtins at end), then by the sorting aspect.
15231 		if (mbr_meta1.builtin != mbr_meta2.builtin)
15232 			return mbr_meta2.builtin;
15233 		else if (mbr_meta1.builtin)
15234 			return mbr_meta1.builtin_type < mbr_meta2.builtin_type;
15235 		else if (mbr_meta1.location == mbr_meta2.location)
15236 			return mbr_meta1.component < mbr_meta2.component;
15237 		else
15238 			return mbr_meta1.location < mbr_meta2.location;
15239 	}
15240 	else
15241 		return mbr_meta1.offset < mbr_meta2.offset;
15242 }
15243 
MemberSorter(SPIRType & t,Meta & m,SortAspect sa)15244 CompilerMSL::MemberSorter::MemberSorter(SPIRType &t, Meta &m, SortAspect sa)
15245     : type(t)
15246     , meta(m)
15247     , sort_aspect(sa)
15248 {
15249 	// Ensure enough meta info is available
15250 	meta.members.resize(max(type.member_types.size(), meta.members.size()));
15251 }
15252 
remap_constexpr_sampler(VariableID id,const MSLConstexprSampler & sampler)15253 void CompilerMSL::remap_constexpr_sampler(VariableID id, const MSLConstexprSampler &sampler)
15254 {
15255 	auto &type = get<SPIRType>(get<SPIRVariable>(id).basetype);
15256 	if (type.basetype != SPIRType::SampledImage && type.basetype != SPIRType::Sampler)
15257 		SPIRV_CROSS_THROW("Can only remap SampledImage and Sampler type.");
15258 	if (!type.array.empty())
15259 		SPIRV_CROSS_THROW("Can not remap array of samplers.");
15260 	constexpr_samplers_by_id[id] = sampler;
15261 }
15262 
remap_constexpr_sampler_by_binding(uint32_t desc_set,uint32_t binding,const MSLConstexprSampler & sampler)15263 void CompilerMSL::remap_constexpr_sampler_by_binding(uint32_t desc_set, uint32_t binding,
15264                                                      const MSLConstexprSampler &sampler)
15265 {
15266 	constexpr_samplers_by_binding[{ desc_set, binding }] = sampler;
15267 }
15268 
cast_from_builtin_load(uint32_t source_id,std::string & expr,const SPIRType & expr_type)15269 void CompilerMSL::cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
15270 {
15271 	auto *var = maybe_get_backing_variable(source_id);
15272 	if (var)
15273 		source_id = var->self;
15274 
15275 	// Only interested in standalone builtin variables.
15276 	if (!has_decoration(source_id, DecorationBuiltIn))
15277 		return;
15278 
15279 	auto builtin = static_cast<BuiltIn>(get_decoration(source_id, DecorationBuiltIn));
15280 	auto expected_type = expr_type.basetype;
15281 	auto expected_width = expr_type.width;
15282 	switch (builtin)
15283 	{
15284 	case BuiltInGlobalInvocationId:
15285 	case BuiltInLocalInvocationId:
15286 	case BuiltInWorkgroupId:
15287 	case BuiltInLocalInvocationIndex:
15288 	case BuiltInWorkgroupSize:
15289 	case BuiltInNumWorkgroups:
15290 	case BuiltInLayer:
15291 	case BuiltInViewportIndex:
15292 	case BuiltInFragStencilRefEXT:
15293 	case BuiltInPrimitiveId:
15294 	case BuiltInSubgroupSize:
15295 	case BuiltInSubgroupLocalInvocationId:
15296 	case BuiltInViewIndex:
15297 	case BuiltInVertexIndex:
15298 	case BuiltInInstanceIndex:
15299 	case BuiltInBaseInstance:
15300 	case BuiltInBaseVertex:
15301 		expected_type = SPIRType::UInt;
15302 		expected_width = 32;
15303 		break;
15304 
15305 	case BuiltInTessLevelInner:
15306 	case BuiltInTessLevelOuter:
15307 		if (get_execution_model() == ExecutionModelTessellationControl)
15308 		{
15309 			expected_type = SPIRType::Half;
15310 			expected_width = 16;
15311 		}
15312 		break;
15313 
15314 	default:
15315 		break;
15316 	}
15317 
15318 	if (expected_type != expr_type.basetype)
15319 	{
15320 		if (!expr_type.array.empty() && (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter))
15321 		{
15322 			// Triggers when loading TessLevel directly as an array.
15323 			// Need explicit padding + cast.
15324 			auto wrap_expr = join(type_to_glsl(expr_type), "({ ");
15325 
15326 			uint32_t array_size = get_physical_tess_level_array_size(builtin);
15327 			for (uint32_t i = 0; i < array_size; i++)
15328 			{
15329 				if (array_size > 1)
15330 					wrap_expr += join("float(", expr, "[", i, "])");
15331 				else
15332 					wrap_expr += join("float(", expr, ")");
15333 				if (i + 1 < array_size)
15334 					wrap_expr += ", ";
15335 			}
15336 
15337 			if (get_execution_mode_bitset().get(ExecutionModeTriangles))
15338 				wrap_expr += ", 0.0";
15339 
15340 			wrap_expr += " })";
15341 			expr = std::move(wrap_expr);
15342 		}
15343 		else
15344 		{
15345 			// These are of different widths, so we cannot do a straight bitcast.
15346 			if (expected_width != expr_type.width)
15347 				expr = join(type_to_glsl(expr_type), "(", expr, ")");
15348 			else
15349 				expr = bitcast_expression(expr_type, expected_type, expr);
15350 		}
15351 	}
15352 
15353 	if (builtin == BuiltInTessCoord && get_entry_point().flags.get(ExecutionModeQuads) && expr_type.vecsize == 3)
15354 	{
15355 		// In SPIR-V, this is always a vec3, even for quads. In Metal, though, it's a float2 for quads.
15356 		// The code is expecting a float3, so we need to widen this.
15357 		expr = join("float3(", expr, ", 0)");
15358 	}
15359 }
15360 
cast_to_builtin_store(uint32_t target_id,std::string & expr,const SPIRType & expr_type)15361 void CompilerMSL::cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
15362 {
15363 	auto *var = maybe_get_backing_variable(target_id);
15364 	if (var)
15365 		target_id = var->self;
15366 
15367 	// Only interested in standalone builtin variables.
15368 	if (!has_decoration(target_id, DecorationBuiltIn))
15369 		return;
15370 
15371 	auto builtin = static_cast<BuiltIn>(get_decoration(target_id, DecorationBuiltIn));
15372 	auto expected_type = expr_type.basetype;
15373 	auto expected_width = expr_type.width;
15374 	switch (builtin)
15375 	{
15376 	case BuiltInLayer:
15377 	case BuiltInViewportIndex:
15378 	case BuiltInFragStencilRefEXT:
15379 	case BuiltInPrimitiveId:
15380 	case BuiltInViewIndex:
15381 		expected_type = SPIRType::UInt;
15382 		expected_width = 32;
15383 		break;
15384 
15385 	case BuiltInTessLevelInner:
15386 	case BuiltInTessLevelOuter:
15387 		expected_type = SPIRType::Half;
15388 		expected_width = 16;
15389 		break;
15390 
15391 	default:
15392 		break;
15393 	}
15394 
15395 	if (expected_type != expr_type.basetype)
15396 	{
15397 		if (expected_width != expr_type.width)
15398 		{
15399 			// These are of different widths, so we cannot do a straight bitcast.
15400 			auto type = expr_type;
15401 			type.basetype = expected_type;
15402 			type.width = expected_width;
15403 			expr = join(type_to_glsl(type), "(", expr, ")");
15404 		}
15405 		else
15406 		{
15407 			auto type = expr_type;
15408 			type.basetype = expected_type;
15409 			expr = bitcast_expression(type, expr_type.basetype, expr);
15410 		}
15411 	}
15412 }
15413 
to_initializer_expression(const SPIRVariable & var)15414 string CompilerMSL::to_initializer_expression(const SPIRVariable &var)
15415 {
15416 	// We risk getting an array initializer here with MSL. If we have an array.
15417 	// FIXME: We cannot handle non-constant arrays being initialized.
15418 	// We will need to inject spvArrayCopy here somehow ...
15419 	auto &type = get<SPIRType>(var.basetype);
15420 	string expr;
15421 	if (ir.ids[var.initializer].get_type() == TypeConstant &&
15422 	    (!type.array.empty() || type.basetype == SPIRType::Struct))
15423 		expr = constant_expression(get<SPIRConstant>(var.initializer));
15424 	else
15425 		expr = CompilerGLSL::to_initializer_expression(var);
15426 	// If the initializer has more vector components than the variable, add a swizzle.
15427 	// FIXME: This can't handle arrays or structs.
15428 	auto &init_type = expression_type(var.initializer);
15429 	if (type.array.empty() && type.basetype != SPIRType::Struct && init_type.vecsize > type.vecsize)
15430 		expr = enclose_expression(expr + vector_swizzle(type.vecsize, 0));
15431 	return expr;
15432 }
15433 
to_zero_initialized_expression(uint32_t)15434 string CompilerMSL::to_zero_initialized_expression(uint32_t)
15435 {
15436 	return "{}";
15437 }
15438 
descriptor_set_is_argument_buffer(uint32_t desc_set) const15439 bool CompilerMSL::descriptor_set_is_argument_buffer(uint32_t desc_set) const
15440 {
15441 	if (!msl_options.argument_buffers)
15442 		return false;
15443 	if (desc_set >= kMaxArgumentBuffers)
15444 		return false;
15445 
15446 	return (argument_buffer_discrete_mask & (1u << desc_set)) == 0;
15447 }
15448 
is_supported_argument_buffer_type(const SPIRType & type) const15449 bool CompilerMSL::is_supported_argument_buffer_type(const SPIRType &type) const
15450 {
15451 	// Very specifically, image load-store in argument buffers are disallowed on MSL on iOS.
15452 	// But we won't know when the argument buffer is encoded whether this image will have
15453 	// a NonWritable decoration. So just use discrete arguments for all storage images
15454 	// on iOS.
15455 	bool is_storage_image = type.basetype == SPIRType::Image && type.image.sampled == 2;
15456 	bool is_supported_type = !msl_options.is_ios() || !is_storage_image;
15457 	return !type_is_msl_framebuffer_fetch(type) && is_supported_type;
15458 }
15459 
analyze_argument_buffers()15460 void CompilerMSL::analyze_argument_buffers()
15461 {
15462 	// Gather all used resources and sort them out into argument buffers.
15463 	// Each argument buffer corresponds to a descriptor set in SPIR-V.
15464 	// The [[id(N)]] values used correspond to the resource mapping we have for MSL.
15465 	// Otherwise, the binding number is used, but this is generally not safe some types like
15466 	// combined image samplers and arrays of resources. Metal needs different indices here,
15467 	// while SPIR-V can have one descriptor set binding. To use argument buffers in practice,
15468 	// you will need to use the remapping from the API.
15469 	for (auto &id : argument_buffer_ids)
15470 		id = 0;
15471 
15472 	// Output resources, sorted by resource index & type.
15473 	struct Resource
15474 	{
15475 		SPIRVariable *var;
15476 		string name;
15477 		SPIRType::BaseType basetype;
15478 		uint32_t index;
15479 		uint32_t plane;
15480 	};
15481 	SmallVector<Resource> resources_in_set[kMaxArgumentBuffers];
15482 	SmallVector<uint32_t> inline_block_vars;
15483 
15484 	bool set_needs_swizzle_buffer[kMaxArgumentBuffers] = {};
15485 	bool set_needs_buffer_sizes[kMaxArgumentBuffers] = {};
15486 	bool needs_buffer_sizes = false;
15487 
15488 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &var) {
15489 		if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
15490 		     var.storage == StorageClassStorageBuffer) &&
15491 		    !is_hidden_variable(var))
15492 		{
15493 			uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
15494 			// Ignore if it's part of a push descriptor set.
15495 			if (!descriptor_set_is_argument_buffer(desc_set))
15496 				return;
15497 
15498 			uint32_t var_id = var.self;
15499 			auto &type = get_variable_data_type(var);
15500 
15501 			if (desc_set >= kMaxArgumentBuffers)
15502 				SPIRV_CROSS_THROW("Descriptor set index is out of range.");
15503 
15504 			const MSLConstexprSampler *constexpr_sampler = nullptr;
15505 			if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
15506 			{
15507 				constexpr_sampler = find_constexpr_sampler(var_id);
15508 				if (constexpr_sampler)
15509 				{
15510 					// Mark this ID as a constexpr sampler for later in case it came from set/bindings.
15511 					constexpr_samplers_by_id[var_id] = *constexpr_sampler;
15512 				}
15513 			}
15514 
15515 			uint32_t binding = get_decoration(var_id, DecorationBinding);
15516 			if (type.basetype == SPIRType::SampledImage)
15517 			{
15518 				add_resource_name(var_id);
15519 
15520 				uint32_t plane_count = 1;
15521 				if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
15522 					plane_count = constexpr_sampler->planes;
15523 
15524 				for (uint32_t i = 0; i < plane_count; i++)
15525 				{
15526 					uint32_t image_resource_index = get_metal_resource_index(var, SPIRType::Image, i);
15527 					resources_in_set[desc_set].push_back(
15528 					    { &var, to_name(var_id), SPIRType::Image, image_resource_index, i });
15529 				}
15530 
15531 				if (type.image.dim != DimBuffer && !constexpr_sampler)
15532 				{
15533 					uint32_t sampler_resource_index = get_metal_resource_index(var, SPIRType::Sampler);
15534 					resources_in_set[desc_set].push_back(
15535 					    { &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index, 0 });
15536 				}
15537 			}
15538 			else if (inline_uniform_blocks.count(SetBindingPair{ desc_set, binding }))
15539 			{
15540 				inline_block_vars.push_back(var_id);
15541 			}
15542 			else if (!constexpr_sampler && is_supported_argument_buffer_type(type))
15543 			{
15544 				// constexpr samplers are not declared as resources.
15545 				// Inline uniform blocks are always emitted at the end.
15546 				add_resource_name(var_id);
15547 				resources_in_set[desc_set].push_back(
15548 					{ &var, to_name(var_id), type.basetype, get_metal_resource_index(var, type.basetype), 0 });
15549 
15550 				// Emulate texture2D atomic operations
15551 				if (atomic_image_vars.count(var.self))
15552 				{
15553 					uint32_t buffer_resource_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0);
15554 					resources_in_set[desc_set].push_back(
15555 						{ &var, to_name(var_id) + "_atomic", SPIRType::Struct, buffer_resource_index, 0 });
15556 				}
15557 			}
15558 
15559 			// Check if this descriptor set needs a swizzle buffer.
15560 			if (needs_swizzle_buffer_def && is_sampled_image_type(type))
15561 				set_needs_swizzle_buffer[desc_set] = true;
15562 			else if (buffers_requiring_array_length.count(var_id) != 0)
15563 			{
15564 				set_needs_buffer_sizes[desc_set] = true;
15565 				needs_buffer_sizes = true;
15566 			}
15567 		}
15568 	});
15569 
15570 	if (needs_swizzle_buffer_def || needs_buffer_sizes)
15571 	{
15572 		uint32_t uint_ptr_type_id = 0;
15573 
15574 		// We might have to add a swizzle buffer resource to the set.
15575 		for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
15576 		{
15577 			if (!set_needs_swizzle_buffer[desc_set] && !set_needs_buffer_sizes[desc_set])
15578 				continue;
15579 
15580 			if (uint_ptr_type_id == 0)
15581 			{
15582 				uint_ptr_type_id = ir.increase_bound_by(1);
15583 
15584 				// Create a buffer to hold extra data, including the swizzle constants.
15585 				SPIRType uint_type_pointer = get_uint_type();
15586 				uint_type_pointer.pointer = true;
15587 				uint_type_pointer.pointer_depth++;
15588 				uint_type_pointer.parent_type = get_uint_type_id();
15589 				uint_type_pointer.storage = StorageClassUniform;
15590 				set<SPIRType>(uint_ptr_type_id, uint_type_pointer);
15591 				set_decoration(uint_ptr_type_id, DecorationArrayStride, 4);
15592 			}
15593 
15594 			if (set_needs_swizzle_buffer[desc_set])
15595 			{
15596 				uint32_t var_id = ir.increase_bound_by(1);
15597 				auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
15598 				set_name(var_id, "spvSwizzleConstants");
15599 				set_decoration(var_id, DecorationDescriptorSet, desc_set);
15600 				set_decoration(var_id, DecorationBinding, kSwizzleBufferBinding);
15601 				resources_in_set[desc_set].push_back(
15602 				    { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0 });
15603 			}
15604 
15605 			if (set_needs_buffer_sizes[desc_set])
15606 			{
15607 				uint32_t var_id = ir.increase_bound_by(1);
15608 				auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
15609 				set_name(var_id, "spvBufferSizeConstants");
15610 				set_decoration(var_id, DecorationDescriptorSet, desc_set);
15611 				set_decoration(var_id, DecorationBinding, kBufferSizeBufferBinding);
15612 				resources_in_set[desc_set].push_back(
15613 				    { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0 });
15614 			}
15615 		}
15616 	}
15617 
15618 	// Now add inline uniform blocks.
15619 	for (uint32_t var_id : inline_block_vars)
15620 	{
15621 		auto &var = get<SPIRVariable>(var_id);
15622 		uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
15623 		add_resource_name(var_id);
15624 		resources_in_set[desc_set].push_back(
15625 		    { &var, to_name(var_id), SPIRType::Struct, get_metal_resource_index(var, SPIRType::Struct), 0 });
15626 	}
15627 
15628 	for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
15629 	{
15630 		auto &resources = resources_in_set[desc_set];
15631 		if (resources.empty())
15632 			continue;
15633 
15634 		assert(descriptor_set_is_argument_buffer(desc_set));
15635 
15636 		uint32_t next_id = ir.increase_bound_by(3);
15637 		uint32_t type_id = next_id + 1;
15638 		uint32_t ptr_type_id = next_id + 2;
15639 		argument_buffer_ids[desc_set] = next_id;
15640 
15641 		auto &buffer_type = set<SPIRType>(type_id);
15642 
15643 		buffer_type.basetype = SPIRType::Struct;
15644 
15645 		if ((argument_buffer_device_storage_mask & (1u << desc_set)) != 0)
15646 		{
15647 			buffer_type.storage = StorageClassStorageBuffer;
15648 			// Make sure the argument buffer gets marked as const device.
15649 			set_decoration(next_id, DecorationNonWritable);
15650 			// Need to mark the type as a Block to enable this.
15651 			set_decoration(type_id, DecorationBlock);
15652 		}
15653 		else
15654 			buffer_type.storage = StorageClassUniform;
15655 
15656 		set_name(type_id, join("spvDescriptorSetBuffer", desc_set));
15657 
15658 		auto &ptr_type = set<SPIRType>(ptr_type_id);
15659 		ptr_type = buffer_type;
15660 		ptr_type.pointer = true;
15661 		ptr_type.pointer_depth++;
15662 		ptr_type.parent_type = type_id;
15663 
15664 		uint32_t buffer_variable_id = next_id;
15665 		set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
15666 		set_name(buffer_variable_id, join("spvDescriptorSet", desc_set));
15667 
15668 		// Ids must be emitted in ID order.
15669 		sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool {
15670 			return tie(lhs.index, lhs.basetype) < tie(rhs.index, rhs.basetype);
15671 		});
15672 
15673 		uint32_t member_index = 0;
15674 		uint32_t next_arg_buff_index = 0;
15675 		for (auto &resource : resources)
15676 		{
15677 			auto &var = *resource.var;
15678 			auto &type = get_variable_data_type(var);
15679 
15680 			// If needed, synthesize and add padding members.
15681 			// member_index and next_arg_buff_index are incremented when padding members are added.
15682 			if (msl_options.pad_argument_buffer_resources)
15683 			{
15684 				while (resource.index > next_arg_buff_index)
15685 				{
15686 					auto &rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index);
15687 					switch (rez_bind.basetype)
15688 					{
15689 					case SPIRType::Void:
15690 					case SPIRType::Boolean:
15691 					case SPIRType::SByte:
15692 					case SPIRType::UByte:
15693 					case SPIRType::Short:
15694 					case SPIRType::UShort:
15695 					case SPIRType::Int:
15696 					case SPIRType::UInt:
15697 					case SPIRType::Int64:
15698 					case SPIRType::UInt64:
15699 					case SPIRType::AtomicCounter:
15700 					case SPIRType::Half:
15701 					case SPIRType::Float:
15702 					case SPIRType::Double:
15703 						add_argument_buffer_padding_buffer_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
15704 						break;
15705 					case SPIRType::Image:
15706 						add_argument_buffer_padding_image_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
15707 						break;
15708 					case SPIRType::Sampler:
15709 						add_argument_buffer_padding_sampler_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
15710 						break;
15711 					case SPIRType::SampledImage:
15712 						if (next_arg_buff_index == rez_bind.msl_sampler)
15713 							add_argument_buffer_padding_sampler_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
15714 						else
15715 							add_argument_buffer_padding_image_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
15716 						break;
15717 					default:
15718 						break;
15719 					}
15720 				}
15721 
15722 				// Adjust the number of slots consumed by current member itself.
15723 				// If actual member is an array, allow runtime array resolution as well.
15724 				uint32_t elem_cnt = type.array.empty() ? 1 : to_array_size_literal(type);
15725 				if (elem_cnt == 0)
15726 					elem_cnt = get_resource_array_size(var.self);
15727 
15728 				next_arg_buff_index += elem_cnt;
15729 			}
15730 
15731 			string mbr_name = ensure_valid_name(resource.name, "m");
15732 			if (resource.plane > 0)
15733 				mbr_name += join(plane_name_suffix, resource.plane);
15734 			set_member_name(buffer_type.self, member_index, mbr_name);
15735 
15736 			if (resource.basetype == SPIRType::Sampler && type.basetype != SPIRType::Sampler)
15737 			{
15738 				// Have to synthesize a sampler type here.
15739 
15740 				bool type_is_array = !type.array.empty();
15741 				uint32_t sampler_type_id = ir.increase_bound_by(type_is_array ? 2 : 1);
15742 				auto &new_sampler_type = set<SPIRType>(sampler_type_id);
15743 				new_sampler_type.basetype = SPIRType::Sampler;
15744 				new_sampler_type.storage = StorageClassUniformConstant;
15745 
15746 				if (type_is_array)
15747 				{
15748 					uint32_t sampler_type_array_id = sampler_type_id + 1;
15749 					auto &sampler_type_array = set<SPIRType>(sampler_type_array_id);
15750 					sampler_type_array = new_sampler_type;
15751 					sampler_type_array.array = type.array;
15752 					sampler_type_array.array_size_literal = type.array_size_literal;
15753 					sampler_type_array.parent_type = sampler_type_id;
15754 					buffer_type.member_types.push_back(sampler_type_array_id);
15755 				}
15756 				else
15757 					buffer_type.member_types.push_back(sampler_type_id);
15758 			}
15759 			else
15760 			{
15761 				uint32_t binding = get_decoration(var.self, DecorationBinding);
15762 				SetBindingPair pair = { desc_set, binding };
15763 
15764 				if (resource.basetype == SPIRType::Image || resource.basetype == SPIRType::Sampler ||
15765 				    resource.basetype == SPIRType::SampledImage)
15766 				{
15767 					// Drop pointer information when we emit the resources into a struct.
15768 					buffer_type.member_types.push_back(get_variable_data_type_id(var));
15769 					if (resource.plane == 0)
15770 						set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
15771 				}
15772 				else if (buffers_requiring_dynamic_offset.count(pair))
15773 				{
15774 					// Don't set the qualified name here; we'll define a variable holding the corrected buffer address later.
15775 					buffer_type.member_types.push_back(var.basetype);
15776 					buffers_requiring_dynamic_offset[pair].second = var.self;
15777 				}
15778 				else if (inline_uniform_blocks.count(pair))
15779 				{
15780 					// Put the buffer block itself into the argument buffer.
15781 					buffer_type.member_types.push_back(get_variable_data_type_id(var));
15782 					set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
15783 				}
15784 				else if (atomic_image_vars.count(var.self))
15785 				{
15786 					// Emulate texture2D atomic operations.
15787 					// Don't set the qualified name: it's already set for this variable,
15788 					// and the code that references the buffer manually appends "_atomic"
15789 					// to the name.
15790 					uint32_t offset = ir.increase_bound_by(2);
15791 					uint32_t atomic_type_id = offset;
15792 					uint32_t type_ptr_id = offset + 1;
15793 
15794 					SPIRType atomic_type;
15795 					atomic_type.basetype = SPIRType::AtomicCounter;
15796 					atomic_type.width = 32;
15797 					atomic_type.vecsize = 1;
15798 					set<SPIRType>(atomic_type_id, atomic_type);
15799 
15800 					atomic_type.pointer = true;
15801 					atomic_type.pointer_depth++;
15802 					atomic_type.parent_type = atomic_type_id;
15803 					atomic_type.storage = StorageClassStorageBuffer;
15804 					auto &atomic_ptr_type = set<SPIRType>(type_ptr_id, atomic_type);
15805 					atomic_ptr_type.self = atomic_type_id;
15806 
15807 					buffer_type.member_types.push_back(type_ptr_id);
15808 				}
15809 				else
15810 				{
15811 					// Resources will be declared as pointers not references, so automatically dereference as appropriate.
15812 					buffer_type.member_types.push_back(var.basetype);
15813 					if (type.array.empty())
15814 						set_qualified_name(var.self, join("(*", to_name(buffer_variable_id), ".", mbr_name, ")"));
15815 					else
15816 						set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
15817 				}
15818 			}
15819 
15820 			set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationResourceIndexPrimary,
15821 			                               resource.index);
15822 			set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationInterfaceOrigID,
15823 			                               var.self);
15824 			member_index++;
15825 		}
15826 	}
15827 }
15828 
15829 // Return the resource type of the app-provided resources for the descriptor set,
15830 // that matches the resource index of the argument buffer index.
15831 // This is a two-step lookup, first lookup the resource binding number from the argument buffer index,
15832 // then lookup the resource binding using the binding number.
get_argument_buffer_resource(uint32_t desc_set,uint32_t arg_idx)15833 MSLResourceBinding &CompilerMSL::get_argument_buffer_resource(uint32_t desc_set, uint32_t arg_idx)
15834 {
15835 	auto stage = get_entry_point().model;
15836 	StageSetBinding arg_idx_tuple = { stage, desc_set, arg_idx };
15837 	auto arg_itr = resource_arg_buff_idx_to_binding_number.find(arg_idx_tuple);
15838 	if (arg_itr != end(resource_arg_buff_idx_to_binding_number))
15839 	{
15840 		StageSetBinding bind_tuple = { stage, desc_set, arg_itr->second };
15841 		auto bind_itr = resource_bindings.find(bind_tuple);
15842 		if (bind_itr != end(resource_bindings))
15843 			return bind_itr->second.first;
15844 	}
15845 	SPIRV_CROSS_THROW("Argument buffer resource base type could not be determined. When padding argument buffer "
15846 	                  "elements, all descriptor set resources must be supplied with a base type by the app.");
15847 }
15848 
15849 // Adds an argument buffer padding argument buffer type as one or more members of the struct type at the member index.
15850 // Metal does not support arrays of buffers, so these are emitted as multiple struct members.
add_argument_buffer_padding_buffer_type(SPIRType & struct_type,uint32_t & mbr_idx,uint32_t & arg_buff_index,MSLResourceBinding & rez_bind)15851 void CompilerMSL::add_argument_buffer_padding_buffer_type(SPIRType &struct_type, uint32_t &mbr_idx,
15852                                                           uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
15853 {
15854 	if (!argument_buffer_padding_buffer_type_id)
15855 	{
15856 		uint32_t buff_type_id = ir.increase_bound_by(2);
15857 		auto &buff_type = set<SPIRType>(buff_type_id);
15858 		buff_type.basetype = rez_bind.basetype;
15859 		buff_type.storage = StorageClassUniformConstant;
15860 
15861 		uint32_t ptr_type_id = buff_type_id + 1;
15862 		auto &ptr_type = set<SPIRType>(ptr_type_id);
15863 		ptr_type = buff_type;
15864 		ptr_type.pointer = true;
15865 		ptr_type.pointer_depth++;
15866 		ptr_type.parent_type = buff_type_id;
15867 
15868 		argument_buffer_padding_buffer_type_id = ptr_type_id;
15869 	}
15870 
15871 	for (uint32_t rez_idx = 0; rez_idx < rez_bind.count; rez_idx++)
15872 		add_argument_buffer_padding_type(argument_buffer_padding_buffer_type_id, struct_type, mbr_idx, arg_buff_index, 1);
15873 }
15874 
15875 // Adds an argument buffer padding argument image type as a member of the struct type at the member index.
add_argument_buffer_padding_image_type(SPIRType & struct_type,uint32_t & mbr_idx,uint32_t & arg_buff_index,MSLResourceBinding & rez_bind)15876 void CompilerMSL::add_argument_buffer_padding_image_type(SPIRType &struct_type, uint32_t &mbr_idx,
15877                                                          uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
15878 {
15879 	if (!argument_buffer_padding_image_type_id)
15880 	{
15881 		uint32_t base_type_id = ir.increase_bound_by(2);
15882 		auto &base_type = set<SPIRType>(base_type_id);
15883 		base_type.basetype = SPIRType::Float;
15884 		base_type.width = 32;
15885 
15886 		uint32_t img_type_id = base_type_id + 1;
15887 		auto &img_type = set<SPIRType>(img_type_id);
15888 		img_type.basetype = SPIRType::Image;
15889 		img_type.storage = StorageClassUniformConstant;
15890 
15891 		img_type.image.type = base_type_id;
15892 		img_type.image.dim = Dim2D;
15893 		img_type.image.depth = false;
15894 		img_type.image.arrayed = false;
15895 		img_type.image.ms = false;
15896 		img_type.image.sampled = 1;
15897 		img_type.image.format = ImageFormatUnknown;
15898 		img_type.image.access = AccessQualifierMax;
15899 
15900 		argument_buffer_padding_image_type_id = img_type_id;
15901 	}
15902 
15903 	add_argument_buffer_padding_type(argument_buffer_padding_image_type_id, struct_type, mbr_idx, arg_buff_index, rez_bind.count);
15904 }
15905 
15906 // Adds an argument buffer padding argument sampler type as a member of the struct type at the member index.
add_argument_buffer_padding_sampler_type(SPIRType & struct_type,uint32_t & mbr_idx,uint32_t & arg_buff_index,MSLResourceBinding & rez_bind)15907 void CompilerMSL::add_argument_buffer_padding_sampler_type(SPIRType &struct_type, uint32_t &mbr_idx,
15908                                                            uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
15909 {
15910 	if (!argument_buffer_padding_sampler_type_id)
15911 	{
15912 		uint32_t samp_type_id = ir.increase_bound_by(1);
15913 		auto &samp_type = set<SPIRType>(samp_type_id);
15914 		samp_type.basetype = SPIRType::Sampler;
15915 		samp_type.storage = StorageClassUniformConstant;
15916 
15917 		argument_buffer_padding_sampler_type_id = samp_type_id;
15918 	}
15919 
15920 	add_argument_buffer_padding_type(argument_buffer_padding_sampler_type_id, struct_type, mbr_idx, arg_buff_index, rez_bind.count);
15921 }
15922 
15923 // Adds the argument buffer padding argument type as a member of the struct type at the member index.
15924 // Advances both arg_buff_index and mbr_idx to next argument slots.
add_argument_buffer_padding_type(uint32_t mbr_type_id,SPIRType & struct_type,uint32_t & mbr_idx,uint32_t & arg_buff_index,uint32_t count)15925 void CompilerMSL::add_argument_buffer_padding_type(uint32_t mbr_type_id, SPIRType &struct_type, uint32_t &mbr_idx,
15926                                                    uint32_t &arg_buff_index, uint32_t count)
15927 {
15928 	uint32_t type_id = mbr_type_id;
15929 	if (count > 1)
15930 	{
15931 		uint32_t ary_type_id = ir.increase_bound_by(1);
15932 		auto &ary_type = set<SPIRType>(ary_type_id);
15933 		ary_type = get<SPIRType>(type_id);
15934 		ary_type.array.push_back(count);
15935 		ary_type.array_size_literal.push_back(true);
15936 		ary_type.parent_type = type_id;
15937 		type_id = ary_type_id;
15938 	}
15939 
15940 	set_member_name(struct_type.self, mbr_idx, join("_m", arg_buff_index, "_pad"));
15941 	set_extended_member_decoration(struct_type.self, mbr_idx, SPIRVCrossDecorationResourceIndexPrimary, arg_buff_index);
15942 	struct_type.member_types.push_back(type_id);
15943 
15944 	arg_buff_index += count;
15945 	mbr_idx++;
15946 }
15947 
activate_argument_buffer_resources()15948 void CompilerMSL::activate_argument_buffer_resources()
15949 {
15950 	// For ABI compatibility, force-enable all resources which are part of argument buffers.
15951 	ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, const SPIRVariable &) {
15952 		if (!has_decoration(self, DecorationDescriptorSet))
15953 			return;
15954 
15955 		uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
15956 		if (descriptor_set_is_argument_buffer(desc_set))
15957 			active_interface_variables.insert(self);
15958 	});
15959 }
15960 
using_builtin_array() const15961 bool CompilerMSL::using_builtin_array() const
15962 {
15963 	return msl_options.force_native_arrays || is_using_builtin_array;
15964 }
15965 
set_combined_sampler_suffix(const char * suffix)15966 void CompilerMSL::set_combined_sampler_suffix(const char *suffix)
15967 {
15968 	sampler_name_suffix = suffix;
15969 }
15970 
get_combined_sampler_suffix() const15971 const char *CompilerMSL::get_combined_sampler_suffix() const
15972 {
15973 	return sampler_name_suffix.c_str();
15974 }
15975 
emit_block_hints(const SPIRBlock &)15976 void CompilerMSL::emit_block_hints(const SPIRBlock &)
15977 {
15978 }
15979 
additional_fixed_sample_mask_str() const15980 string CompilerMSL::additional_fixed_sample_mask_str() const
15981 {
15982 	char print_buffer[32];
15983 	sprintf(print_buffer, "0x%x", msl_options.additional_fixed_sample_mask);
15984 	return print_buffer;
15985 }
15986