/*
 * Copyright 2016-2021 The Brenwill Workshop Ltd.
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * At your option, you may choose to accept this material under either:
 *  1. The Apache License, Version 2.0, found at <http://www.apache.org/licenses/LICENSE-2.0>, or
 *  2. The MIT License, found at <http://opensource.org/licenses/MIT>.
 */

#include "spirv_msl.hpp"
#include "GLSL.std.450.h"

#include <algorithm>
#include <assert.h>
#include <numeric>

using namespace spv;
using namespace SPIRV_CROSS_NAMESPACE;
using namespace std;

static const uint32_t k_unknown_location = ~0u;
static const uint32_t k_unknown_component = ~0u;
static const char *force_inline = "static inline __attribute__((always_inline))";

CompilerMSL::CompilerMSL(std::vector<uint32_t> spirv_)
    : CompilerGLSL(move(spirv_))
{
}

CompilerMSL::CompilerMSL(const uint32_t *ir_, size_t word_count)
    : CompilerGLSL(ir_, word_count)
{
}

CompilerMSL::CompilerMSL(const ParsedIR &ir_)
    : CompilerGLSL(ir_)
{
}

CompilerMSL::CompilerMSL(ParsedIR &&ir_)
    : CompilerGLSL(std::move(ir_))
{
}

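// A minimal usage sketch (illustrative client code, not part of this file; the
// variable names are hypothetical): construct the compiler from SPIR-V words,
// configure MSL options, then compile to MSL source.
//
//     CompilerMSL msl(std::move(spirv_words));
//     CompilerMSL::Options opts;
//     opts.set_msl_version(2, 1);
//     msl.set_msl_options(opts);
//     std::string msl_source = msl.compile();

// Registers the layout of a shader input by location, and also indexes it by
// builtin when one is specified, so builtin inputs can be looked up directly.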
void CompilerMSL::add_msl_shader_input(const MSLShaderInput &si)
{
	inputs_by_location[si.location] = si;
	if (si.builtin != BuiltInMax && !inputs_by_builtin.count(si.builtin))
		inputs_by_builtin[si.builtin] = si;
}

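// Maps a resource at a given stage/descriptor-set/binding tuple to explicit MSL
// resource indices. The binding is recorded as unused until the compiler actually
// consumes it during compilation.
//
// For illustration (hypothetical values), a client mapping set 0, binding 1 of a
// fragment shader to MSL buffer slot 3 might write:
//
//     MSLResourceBinding b;
//     b.stage = spv::ExecutionModelFragment;
//     b.desc_set = 0;
//     b.binding = 1;
//     b.msl_buffer = 3;
//     msl.add_msl_resource_binding(b);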
void CompilerMSL::add_msl_resource_binding(const MSLResourceBinding &binding)
{
	StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
	resource_bindings[tuple] = { binding, false };

	// If we might need to pad argument buffer members to positionally align
	// arg buffer indexes, also maintain a lookup by argument buffer index.
	if (msl_options.pad_argument_buffer_resources)
	{
		StageSetBinding arg_idx_tuple = { binding.stage, binding.desc_set, k_unknown_component };

#define ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(rez) \
	arg_idx_tuple.binding = binding.msl_##rez; \
	resource_arg_buff_idx_to_binding_number[arg_idx_tuple] = binding.binding

		switch (binding.basetype)
		{
		case SPIRType::Void:
		case SPIRType::Boolean:
		case SPIRType::SByte:
		case SPIRType::UByte:
		case SPIRType::Short:
		case SPIRType::UShort:
		case SPIRType::Int:
		case SPIRType::UInt:
		case SPIRType::Int64:
		case SPIRType::UInt64:
		case SPIRType::AtomicCounter:
		case SPIRType::Half:
		case SPIRType::Float:
		case SPIRType::Double:
			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(buffer);
			break;
		case SPIRType::Image:
			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(texture);
			break;
		case SPIRType::Sampler:
			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(sampler);
			break;
		case SPIRType::SampledImage:
			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(texture);
			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(sampler);
			break;
		default:
			SPIRV_CROSS_THROW("Unexpected argument buffer resource base type. When padding argument buffer elements, "
			                  "all descriptor set resources must be supplied with a base type by the app.");
		}
#undef ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP
	}
}

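// Marks the buffer at the given descriptor set and binding as one whose offset is
// supplied at runtime through the dynamic-offset buffer, at the given index.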
void CompilerMSL::add_dynamic_buffer(uint32_t desc_set, uint32_t binding, uint32_t index)
{
	SetBindingPair pair = { desc_set, binding };
	buffers_requiring_dynamic_offset[pair] = { index, 0 };
}

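// Marks the resource at the given descriptor set and binding as an inline uniform
// block, so its data is embedded directly in the argument buffer.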
void CompilerMSL::add_inline_uniform_block(uint32_t desc_set, uint32_t binding)
{
	SetBindingPair pair = { desc_set, binding };
	inline_uniform_blocks.insert(pair);
}

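// When argument buffers are in use, excludes the given descriptor set from them;
// resources in that set are bound discretely instead.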
void CompilerMSL::add_discrete_descriptor_set(uint32_t desc_set)
{
	if (desc_set < kMaxArgumentBuffers)
		argument_buffer_discrete_mask |= 1u << desc_set;
}

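// Controls whether the argument buffer for the given descriptor set is declared in
// the device address space (writable) rather than the default constant space.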
void CompilerMSL::set_argument_buffer_device_address_space(uint32_t desc_set, bool device_storage)
{
	if (desc_set < kMaxArgumentBuffers)
	{
		if (device_storage)
			argument_buffer_device_storage_mask |= 1u << desc_set;
		else
			argument_buffer_device_storage_mask &= ~(1u << desc_set);
	}
}

bool CompilerMSL::is_msl_shader_input_used(uint32_t location)
{
	// Don't report internal location allocations to app.
	return location_inputs_in_use.count(location) != 0 &&
	       location_inputs_in_use_fallback.count(location) == 0;
}

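// Returns the input location that was automatically assigned to the given builtin,
// or k_unknown_location if none was assigned.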
uint32_t CompilerMSL::get_automatic_builtin_input_location(spv::BuiltIn builtin) const
{
	auto itr = builtin_to_automatic_input_location.find(builtin);
	if (itr == builtin_to_automatic_input_location.end())
		return k_unknown_location;
	else
		return itr->second;
}

bool CompilerMSL::is_msl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
{
	StageSetBinding tuple = { model, desc_set, binding };
	auto itr = resource_bindings.find(tuple);
	return itr != end(resource_bindings) && itr->second.second;
}

// Returns the size of the array of resources used by the variable with the specified id.
// The returned value is retrieved from the resource binding added using add_msl_resource_binding().
uint32_t CompilerMSL::get_resource_array_size(uint32_t id) const
{
	StageSetBinding tuple = { get_entry_point().model, get_decoration(id, DecorationDescriptorSet),
		                      get_decoration(id, DecorationBinding) };
	auto itr = resource_bindings.find(tuple);
	return itr != end(resource_bindings) ? itr->second.first.count : 0;
}

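// The accessors below return the MSL resource index that was automatically assigned
// to the object with the given ID. The secondary index covers, for example, the
// sampler half of a combined image-sampler; the tertiary and quaternary indices
// cover additional planes of multi-planar images.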
uint32_t CompilerMSL::get_automatic_msl_resource_binding(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexPrimary);
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding_secondary(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexSecondary);
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding_tertiary(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexTertiary);
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding_quaternary(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexQuaternary);
}

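// Records how many components the render target at the given location has, so the
// fragment shader output can be narrowed to match the attachment.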
void CompilerMSL::set_fragment_output_components(uint32_t location, uint32_t components)
{
	fragment_output_components[location] = components;
}

bool CompilerMSL::builtin_translates_to_nonarray(spv::BuiltIn builtin) const
{
	return (builtin == BuiltInSampleMask);
}

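// Adds any implicit builtin variables and auxiliary buffers that the MSL translation
// requires but the SPIR-V did not declare: e.g. gl_FragCoord for subpass inputs,
// gl_SampleID, vertex/instance parameters, subgroup builtins, the swizzle,
// buffer-size and view-mask buffers, and an implicit gl_Position.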
void CompilerMSL::build_implicit_builtins()
{
	bool need_sample_pos = active_input_builtins.get(BuiltInSamplePosition);
	bool need_vertex_params = capture_output_to_buffer && get_execution_model() == ExecutionModelVertex &&
	                          !msl_options.vertex_for_tessellation;
	bool need_tesc_params = get_execution_model() == ExecutionModelTessellationControl;
	bool need_subgroup_mask =
	    active_input_builtins.get(BuiltInSubgroupEqMask) || active_input_builtins.get(BuiltInSubgroupGeMask) ||
	    active_input_builtins.get(BuiltInSubgroupGtMask) || active_input_builtins.get(BuiltInSubgroupLeMask) ||
	    active_input_builtins.get(BuiltInSubgroupLtMask);
	bool need_subgroup_ge_mask = !msl_options.is_ios() && (active_input_builtins.get(BuiltInSubgroupGeMask) ||
	                                                       active_input_builtins.get(BuiltInSubgroupGtMask));
	bool need_multiview = get_execution_model() == ExecutionModelVertex && !msl_options.view_index_from_device_index &&
	                      msl_options.multiview_layered_rendering &&
	                      (msl_options.multiview || active_input_builtins.get(BuiltInViewIndex));
	bool need_dispatch_base =
	    msl_options.dispatch_base && get_execution_model() == ExecutionModelGLCompute &&
	    (active_input_builtins.get(BuiltInWorkgroupId) || active_input_builtins.get(BuiltInGlobalInvocationId));
	bool need_grid_params = get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation;
	bool need_vertex_base_params =
	    need_grid_params &&
	    (active_input_builtins.get(BuiltInVertexId) || active_input_builtins.get(BuiltInVertexIndex) ||
	     active_input_builtins.get(BuiltInBaseVertex) || active_input_builtins.get(BuiltInInstanceId) ||
	     active_input_builtins.get(BuiltInInstanceIndex) || active_input_builtins.get(BuiltInBaseInstance));
	bool need_local_invocation_index = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInSubgroupId);
	bool need_workgroup_size = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInNumSubgroups);

	if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
	    need_multiview || need_dispatch_base || need_vertex_base_params || need_grid_params || needs_sample_id ||
	    needs_subgroup_invocation_id || needs_subgroup_size || has_additional_fixed_sample_mask() ||
	    need_local_invocation_index || need_workgroup_size)
	{
		bool has_frag_coord = false;
		bool has_sample_id = false;
		bool has_vertex_idx = false;
		bool has_base_vertex = false;
		bool has_instance_idx = false;
		bool has_base_instance = false;
		bool has_invocation_id = false;
		bool has_primitive_id = false;
		bool has_subgroup_invocation_id = false;
		bool has_subgroup_size = false;
		bool has_view_idx = false;
		bool has_layer = false;
		bool has_local_invocation_index = false;
		bool has_workgroup_size = false;
		uint32_t workgroup_id_type = 0;

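		// First pass: scan the entry point's declared interface variables and note
		// which of the required builtins are already present, capturing their IDs.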
		ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
			if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
				return;
			if (!interface_variable_exists_in_entry_point(var.self))
				return;
			if (!has_decoration(var.self, DecorationBuiltIn))
				return;

			BuiltIn builtin = ir.meta[var.self].decoration.builtin_type;

			if (var.storage == StorageClassOutput)
			{
				if (has_additional_fixed_sample_mask() && builtin == BuiltInSampleMask)
				{
					builtin_sample_mask_id = var.self;
					mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var.self);
					does_shader_write_sample_mask = true;
				}
			}

			if (var.storage != StorageClassInput)
				return;

			// When not using Metal's native frame-buffer fetch API for subpass inputs, the
			// input attachment is read as a texture, which requires gl_FragCoord (and the
			// layer or view index for arrayed/multiview cases).
			if (need_subpass_input && (!msl_options.use_framebuffer_fetch_subpasses))
			{
				switch (builtin)
				{
				case BuiltInFragCoord:
					mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var.self);
					builtin_frag_coord_id = var.self;
					has_frag_coord = true;
					break;
				case BuiltInLayer:
					if (!msl_options.arrayed_subpass_input || msl_options.multiview)
						break;
					mark_implicit_builtin(StorageClassInput, BuiltInLayer, var.self);
					builtin_layer_id = var.self;
					has_layer = true;
					break;
				case BuiltInViewIndex:
					if (!msl_options.multiview)
						break;
					mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self);
					builtin_view_idx_id = var.self;
					has_view_idx = true;
					break;
				default:
					break;
				}
			}

			if ((need_sample_pos || needs_sample_id) && builtin == BuiltInSampleId)
			{
				builtin_sample_id_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var.self);
				has_sample_id = true;
			}

			if (need_vertex_params)
			{
				switch (builtin)
				{
				case BuiltInVertexIndex:
					builtin_vertex_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var.self);
					has_vertex_idx = true;
					break;
				case BuiltInBaseVertex:
					builtin_base_vertex_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var.self);
					has_base_vertex = true;
					break;
				case BuiltInInstanceIndex:
					builtin_instance_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self);
					has_instance_idx = true;
					break;
				case BuiltInBaseInstance:
					builtin_base_instance_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self);
					has_base_instance = true;
					break;
				default:
					break;
				}
			}

			if (need_tesc_params)
			{
				switch (builtin)
				{
				case BuiltInInvocationId:
					builtin_invocation_id_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var.self);
					has_invocation_id = true;
					break;
				case BuiltInPrimitiveId:
					builtin_primitive_id_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var.self);
					has_primitive_id = true;
					break;
				default:
					break;
				}
			}

			if ((need_subgroup_mask || needs_subgroup_invocation_id) && builtin == BuiltInSubgroupLocalInvocationId)
			{
				builtin_subgroup_invocation_id_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var.self);
				has_subgroup_invocation_id = true;
			}

			if ((need_subgroup_ge_mask || needs_subgroup_size) && builtin == BuiltInSubgroupSize)
			{
				builtin_subgroup_size_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var.self);
				has_subgroup_size = true;
			}

			if (need_multiview)
			{
				switch (builtin)
				{
				case BuiltInInstanceIndex:
					// The view index here is derived from the instance index.
					builtin_instance_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self);
					has_instance_idx = true;
					break;
				case BuiltInBaseInstance:
					// If a non-zero base instance is used, we need to adjust for it when calculating the view index.
					builtin_base_instance_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self);
					has_base_instance = true;
					break;
				case BuiltInViewIndex:
					builtin_view_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self);
					has_view_idx = true;
					break;
				default:
					break;
				}
			}

			if (need_local_invocation_index && builtin == BuiltInLocalInvocationIndex)
			{
				builtin_local_invocation_index_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInLocalInvocationIndex, var.self);
				has_local_invocation_index = true;
			}

			if (need_workgroup_size && builtin == BuiltInLocalInvocationId)
			{
				builtin_workgroup_size_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInWorkgroupSize, var.self);
				has_workgroup_size = true;
			}

			// The base workgroup needs to have the same type and vector size
			// as the workgroup or invocation ID, so keep track of the type that
			// was used.
			if (need_dispatch_base && workgroup_id_type == 0 &&
			    (builtin == BuiltInWorkgroupId || builtin == BuiltInGlobalInvocationId))
				workgroup_id_type = var.basetype;
		});

		// When not using Metal's native frame-buffer fetch API, create any of gl_FragCoord,
		// gl_Layer and gl_ViewIndex that subpass input reads require but the SPIR-V lacks.
		if ((!has_frag_coord || (msl_options.multiview && !has_view_idx) ||
		     (msl_options.arrayed_subpass_input && !msl_options.multiview && !has_layer)) &&
		    (!msl_options.use_framebuffer_fetch_subpasses) && need_subpass_input)
		{
			if (!has_frag_coord)
			{
				uint32_t offset = ir.increase_bound_by(3);
				uint32_t type_id = offset;
				uint32_t type_ptr_id = offset + 1;
				uint32_t var_id = offset + 2;

				// Create gl_FragCoord.
				SPIRType vec4_type;
				vec4_type.basetype = SPIRType::Float;
				vec4_type.width = 32;
				vec4_type.vecsize = 4;
				set<SPIRType>(type_id, vec4_type);

				SPIRType vec4_type_ptr;
				vec4_type_ptr = vec4_type;
				vec4_type_ptr.pointer = true;
				vec4_type_ptr.pointer_depth++;
				vec4_type_ptr.parent_type = type_id;
				vec4_type_ptr.storage = StorageClassInput;
				auto &ptr_type = set<SPIRType>(type_ptr_id, vec4_type_ptr);
				ptr_type.self = type_id;

				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInFragCoord);
				builtin_frag_coord_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var_id);
			}

			if (!has_layer && msl_options.arrayed_subpass_input && !msl_options.multiview)
			{
				uint32_t offset = ir.increase_bound_by(2);
				uint32_t type_ptr_id = offset;
				uint32_t var_id = offset + 1;

				// Create gl_Layer.
				SPIRType uint_type_ptr;
				uint_type_ptr = get_uint_type();
				uint_type_ptr.pointer = true;
				uint_type_ptr.pointer_depth++;
				uint_type_ptr.parent_type = get_uint_type_id();
				uint_type_ptr.storage = StorageClassInput;
				auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
				ptr_type.self = get_uint_type_id();

				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
				builtin_layer_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInLayer, var_id);
			}

			if (!has_view_idx && msl_options.multiview)
			{
				uint32_t offset = ir.increase_bound_by(2);
				uint32_t type_ptr_id = offset;
				uint32_t var_id = offset + 1;

				// Create gl_ViewIndex.
				SPIRType uint_type_ptr;
				uint_type_ptr = get_uint_type();
				uint_type_ptr.pointer = true;
				uint_type_ptr.pointer_depth++;
				uint_type_ptr.parent_type = get_uint_type_id();
				uint_type_ptr.storage = StorageClassInput;
				auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
				ptr_type.self = get_uint_type_id();

				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
				builtin_view_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
			}
		}

		if (!has_sample_id && (need_sample_pos || needs_sample_id))
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_SampleID.
			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSampleId);
			builtin_sample_id_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var_id);
		}

		if ((need_vertex_params && (!has_vertex_idx || !has_base_vertex || !has_instance_idx || !has_base_instance)) ||
		    (need_multiview && (!has_instance_idx || !has_base_instance || !has_view_idx)))
		{
			uint32_t type_ptr_id = ir.increase_bound_by(1);

			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			if (need_vertex_params && !has_vertex_idx)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_VertexIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInVertexIndex);
				builtin_vertex_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var_id);
			}

			if (need_vertex_params && !has_base_vertex)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_BaseVertex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInBaseVertex);
				builtin_base_vertex_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var_id);
			}

			if (!has_instance_idx) // Needed by both multiview and tessellation
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_InstanceIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInInstanceIndex);
				builtin_instance_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var_id);
			}

			if (!has_base_instance) // Needed by both multiview and tessellation
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_BaseInstance.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInBaseInstance);
				builtin_base_instance_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var_id);
			}

			if (need_multiview)
			{
				// Multiview shaders are not allowed to write to gl_Layer, ostensibly because
				// it is implicitly written from gl_ViewIndex, but we have to do that explicitly.
				// Note that we can't just abuse gl_ViewIndex for this purpose: it's an input, but
				// gl_Layer is an output in vertex-pipeline shaders.
				uint32_t type_ptr_out_id = ir.increase_bound_by(2);
				SPIRType uint_type_ptr_out;
				uint_type_ptr_out = get_uint_type();
				uint_type_ptr_out.pointer = true;
				uint_type_ptr_out.pointer_depth++;
				uint_type_ptr_out.parent_type = get_uint_type_id();
				uint_type_ptr_out.storage = StorageClassOutput;
				auto &ptr_out_type = set<SPIRType>(type_ptr_out_id, uint_type_ptr_out);
				ptr_out_type.self = get_uint_type_id();
				uint32_t var_id = type_ptr_out_id + 1;
				set<SPIRVariable>(var_id, type_ptr_out_id, StorageClassOutput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
				builtin_layer_id = var_id;
				mark_implicit_builtin(StorageClassOutput, BuiltInLayer, var_id);
			}

			if (need_multiview && !has_view_idx)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_ViewIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
				builtin_view_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
			}
		}

		if ((need_tesc_params && (msl_options.multi_patch_workgroup || !has_invocation_id || !has_primitive_id)) ||
		    need_grid_params)
		{
			uint32_t type_ptr_id = ir.increase_bound_by(1);

			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			if (msl_options.multi_patch_workgroup || need_grid_params)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_GlobalInvocationID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInGlobalInvocationId);
				builtin_invocation_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInGlobalInvocationId, var_id);
			}
			else if (need_tesc_params && !has_invocation_id)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_InvocationID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInInvocationId);
				builtin_invocation_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var_id);
			}

			if (need_tesc_params && !has_primitive_id)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_PrimitiveID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInPrimitiveId);
				builtin_primitive_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var_id);
			}

			if (need_grid_params)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				set<SPIRVariable>(var_id, build_extended_vector_type(get_uint_type_id(), 3), StorageClassInput);
				set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize);
				get_entry_point().interface_variables.push_back(var_id);
				set_name(var_id, "spvStageInputSize");
				builtin_stage_input_size_id = var_id;
			}
		}

		if (!has_subgroup_invocation_id && (need_subgroup_mask || needs_subgroup_invocation_id))
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_SubgroupInvocationID.
			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupLocalInvocationId);
			builtin_subgroup_invocation_id_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var_id);
		}

		if (!has_subgroup_size && (need_subgroup_ge_mask || needs_subgroup_size))
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_SubgroupSize.
			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupSize);
			builtin_subgroup_size_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var_id);
		}

		if (need_dispatch_base || need_vertex_base_params)
		{
			if (workgroup_id_type == 0)
				workgroup_id_type = build_extended_vector_type(get_uint_type_id(), 3);
			uint32_t var_id;
			if (msl_options.supports_msl_version(1, 2))
			{
				// If we have MSL 1.2, we can (ab)use the [[grid_origin]] builtin
				// to convey this information and save a buffer slot.
				uint32_t offset = ir.increase_bound_by(1);
				var_id = offset;

				set<SPIRVariable>(var_id, workgroup_id_type, StorageClassInput);
				set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase);
				get_entry_point().interface_variables.push_back(var_id);
			}
			else
			{
				// Otherwise, we need to fall back to a good ol' fashioned buffer.
				uint32_t offset = ir.increase_bound_by(2);
				var_id = offset;
				uint32_t type_id = offset + 1;

				SPIRType var_type = get<SPIRType>(workgroup_id_type);
				var_type.storage = StorageClassUniform;
				set<SPIRType>(type_id, var_type);

				set<SPIRVariable>(var_id, type_id, StorageClassUniform);
				// This should never match anything.
				set_decoration(var_id, DecorationDescriptorSet, ~(5u));
				set_decoration(var_id, DecorationBinding, msl_options.indirect_params_buffer_index);
				set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
				                        msl_options.indirect_params_buffer_index);
			}
			set_name(var_id, "spvDispatchBase");
			builtin_dispatch_base_id = var_id;
		}

		if (has_additional_fixed_sample_mask() && !does_shader_write_sample_mask)
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t var_id = offset + 1;

			// Create gl_SampleMask.
			SPIRType uint_type_ptr_out;
			uint_type_ptr_out = get_uint_type();
			uint_type_ptr_out.pointer = true;
			uint_type_ptr_out.pointer_depth++;
			uint_type_ptr_out.parent_type = get_uint_type_id();
			uint_type_ptr_out.storage = StorageClassOutput;

			auto &ptr_out_type = set<SPIRType>(offset, uint_type_ptr_out);
			ptr_out_type.self = get_uint_type_id();
			set<SPIRVariable>(var_id, offset, StorageClassOutput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSampleMask);
			builtin_sample_mask_id = var_id;
			mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var_id);
		}

		if (need_local_invocation_index && !has_local_invocation_index)
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_LocalInvocationIndex.
			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;

			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();
			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInLocalInvocationIndex);
			builtin_local_invocation_index_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInLocalInvocationIndex, var_id);
		}

		if (need_workgroup_size && !has_workgroup_size)
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_WorkgroupSize.
			uint32_t type_id = build_extended_vector_type(get_uint_type_id(), 3);
			SPIRType uint_type_ptr = get<SPIRType>(type_id);
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = type_id;
			uint_type_ptr.storage = StorageClassInput;

			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = type_id;
			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInWorkgroupSize);
			builtin_workgroup_size_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInWorkgroupSize, var_id);
		}
	}

	if (needs_swizzle_buffer_def)
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvSwizzleConstants");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, kSwizzleBufferBinding);
		set_decoration(var_id, DecorationBinding, msl_options.swizzle_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.swizzle_buffer_index);
		swizzle_buffer_id = var_id;
	}

	if (!buffers_requiring_array_length.empty())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvBufferSizeConstants");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, kBufferSizeBufferBinding);
		set_decoration(var_id, DecorationBinding, msl_options.buffer_size_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.buffer_size_buffer_index);
		buffer_size_buffer_id = var_id;
	}

	if (needs_view_mask_buffer())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvViewMask");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, ~(4u));
		set_decoration(var_id, DecorationBinding, msl_options.view_mask_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.view_mask_buffer_index);
		view_mask_buffer_id = var_id;
	}

	if (!buffers_requiring_dynamic_offset.empty())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvDynamicOffsets");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, ~(5u));
		set_decoration(var_id, DecorationBinding, msl_options.dynamic_offsets_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
		                        msl_options.dynamic_offsets_buffer_index);
		dynamic_offsets_buffer_id = var_id;
	}

	// If we're returning a struct from a vertex-like entry point, we must return a position attribute.
	bool need_position =
	    (get_execution_model() == ExecutionModelVertex ||
	     get_execution_model() == ExecutionModelTessellationEvaluation) &&
	    !capture_output_to_buffer && !get_is_rasterization_disabled() &&
	    !active_output_builtins.get(BuiltInPosition);

	if (need_position)
	{
		// If we can get away with returning void from entry point, we don't need to care.
		// If there is at least one other stage output, we need to return [[position]],
		// so we need to create one if it doesn't appear in the SPIR-V. Before adding the
		// implicit variable, check if it actually exists already, but just has not been used
		// or initialized, and if so, mark it as active, and do not create the implicit variable.
		bool has_output = false;
		ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
			if (var.storage == StorageClassOutput && interface_variable_exists_in_entry_point(var.self))
			{
				has_output = true;

				// Check if the var is the Position builtin.
				if (has_decoration(var.self, DecorationBuiltIn) && get_decoration(var.self, DecorationBuiltIn) == BuiltInPosition)
					active_output_builtins.set(BuiltInPosition);

				// If the var is a struct, check if any member is the Position builtin.
				auto &var_type = get_variable_element_type(var);
				if (var_type.basetype == SPIRType::Struct)
				{
					auto mbr_cnt = var_type.member_types.size();
					for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
					{
						auto builtin = BuiltInMax;
						bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
						if (is_builtin && builtin == BuiltInPosition)
							active_output_builtins.set(BuiltInPosition);
					}
				}
			}
		});
		need_position = has_output && !active_output_builtins.get(BuiltInPosition);
	}

	if (need_position)
	{
		uint32_t offset = ir.increase_bound_by(3);
		uint32_t type_id = offset;
		uint32_t type_ptr_id = offset + 1;
		uint32_t var_id = offset + 2;

		// Create gl_Position.
		SPIRType vec4_type;
		vec4_type.basetype = SPIRType::Float;
		vec4_type.width = 32;
		vec4_type.vecsize = 4;
		set<SPIRType>(type_id, vec4_type);

		SPIRType vec4_type_ptr;
		vec4_type_ptr = vec4_type;
		vec4_type_ptr.pointer = true;
		vec4_type_ptr.pointer_depth++;
		vec4_type_ptr.parent_type = type_id;
		vec4_type_ptr.storage = StorageClassOutput;
		auto &ptr_type = set<SPIRType>(type_ptr_id, vec4_type_ptr);
		ptr_type.self = type_id;

		set<SPIRVariable>(var_id, type_ptr_id, StorageClassOutput);
		set_decoration(var_id, DecorationBuiltIn, BuiltInPosition);
		mark_implicit_builtin(StorageClassOutput, BuiltInPosition, var_id);
	}
}

// Checks if the specified builtin variable (e.g. gl_InstanceIndex) is marked as active.
// If not, it marks it as active and forces a recompilation.
// This might be used when the optimization of inactive builtins was too optimistic (e.g. when "spvOut" is emitted).
void CompilerMSL::ensure_builtin(spv::StorageClass storage, spv::BuiltIn builtin)
{
	Bitset *active_builtins = nullptr;
	switch (storage)
	{
	case StorageClassInput:
		active_builtins = &active_input_builtins;
		break;

	case StorageClassOutput:
		active_builtins = &active_output_builtins;
		break;

	default:
		break;
	}

	// At this point, the specified builtin variable must have already been declared in the entry point.
	// If not, mark as active and force recompile.
	if (active_builtins != nullptr && !active_builtins->get(builtin))
	{
		active_builtins->set(builtin);
		force_recompile();
	}
}

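// Marks the builtin as active in the given storage class, and ensures the variable
// is listed among the entry point's interface variables.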
void CompilerMSL::mark_implicit_builtin(StorageClass storage, BuiltIn builtin, uint32_t id)
{
	Bitset *active_builtins = nullptr;
	switch (storage)
	{
	case StorageClassInput:
		active_builtins = &active_input_builtins;
		break;

	case StorageClassOutput:
		active_builtins = &active_output_builtins;
		break;

	default:
		break;
	}

	assert(active_builtins != nullptr);
	active_builtins->set(builtin);

	auto &var = get_entry_point().interface_variables;
	if (find(begin(var), end(var), VariableID(id)) == end(var))
		var.push_back(id);
}

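// Synthesizes a variable typed as a pointer to an array of 32-bit uints, used for
// auxiliary data such as swizzle constants, buffer sizes, view masks and dynamic offsets.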
uint32_t CompilerMSL::build_constant_uint_array_pointer()
{
	uint32_t offset = ir.increase_bound_by(3);
	uint32_t type_ptr_id = offset;
	uint32_t type_ptr_ptr_id = offset + 1;
	uint32_t var_id = offset + 2;

	// Create a buffer to hold extra data, including the swizzle constants.
	SPIRType uint_type_pointer = get_uint_type();
	uint_type_pointer.pointer = true;
	uint_type_pointer.pointer_depth++;
	uint_type_pointer.parent_type = get_uint_type_id();
	uint_type_pointer.storage = StorageClassUniform;
	set<SPIRType>(type_ptr_id, uint_type_pointer);
	set_decoration(type_ptr_id, DecorationArrayStride, 4);

	SPIRType uint_type_pointer2 = uint_type_pointer;
	uint_type_pointer2.pointer_depth++;
	uint_type_pointer2.parent_type = type_ptr_id;
	set<SPIRType>(type_ptr_ptr_id, uint_type_pointer2);

	set<SPIRVariable>(var_id, type_ptr_ptr_id, StorageClassUniformConstant);
	return var_id;
}

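// Returns the MSL address-mode argument for a constexpr sampler declaration, with an
// optional per-coordinate prefix ("s_", "t_" or "r_").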
static string create_sampler_address(const char *prefix, MSLSamplerAddress addr)
{
	switch (addr)
	{
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE:
		return join(prefix, "address::clamp_to_edge");
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_ZERO:
		return join(prefix, "address::clamp_to_zero");
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER:
		return join(prefix, "address::clamp_to_border");
	case MSL_SAMPLER_ADDRESS_REPEAT:
		return join(prefix, "address::repeat");
	case MSL_SAMPLER_ADDRESS_MIRRORED_REPEAT:
		return join(prefix, "address::mirrored_repeat");
	default:
		SPIRV_CROSS_THROW("Invalid sampler addressing mode.");
	}
}

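// Accessors for the data types of the synthesized stage_in/stage_out interface blocks,
// including the per-patch variants used by tessellation.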
SPIRType &CompilerMSL::get_stage_in_struct_type()
{
	auto &si_var = get<SPIRVariable>(stage_in_var_id);
	return get_variable_data_type(si_var);
}

SPIRType &CompilerMSL::get_stage_out_struct_type()
{
	auto &so_var = get<SPIRVariable>(stage_out_var_id);
	return get_variable_data_type(so_var);
}

SPIRType &CompilerMSL::get_patch_stage_in_struct_type()
{
	auto &si_var = get<SPIRVariable>(patch_stage_in_var_id);
	return get_variable_data_type(si_var);
}

SPIRType &CompilerMSL::get_patch_stage_out_struct_type()
{
	auto &so_var = get<SPIRVariable>(patch_stage_out_var_id);
	return get_variable_data_type(so_var);
}

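// Returns the name of the Metal tessellation-factor struct, which depends on whether
// the patch is triangle- or quad-based.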
std::string CompilerMSL::get_tess_factor_struct_name()
{
	if (get_entry_point().flags.get(ExecutionModeTriangles))
		return "MTLTriangleTessellationFactorsHalf";
	return "MTLQuadTessellationFactorsHalf";
}

SPIRType &CompilerMSL::get_uint_type()
{
	return get<SPIRType>(get_uint_type_id());
}

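// Returns the type ID of a 32-bit uint scalar, creating the type on first use.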
uint32_t CompilerMSL::get_uint_type_id()
{
	if (uint_type_id != 0)
		return uint_type_id;

	uint_type_id = ir.increase_bound_by(1);

	SPIRType type;
	type.basetype = SPIRType::UInt;
	type.width = 32;
	set<SPIRType>(uint_type_id, type);
	return uint_type_id;
}

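// Emits declarations that must appear at the top of the entry-point function body:
// constexpr samplers, dynamically offset buffers, buffer arrays, and disabled
// fragment outputs.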
void CompilerMSL::emit_entry_point_declarations()
{
	// FIXME: Get test coverage here ...
	// Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries.
	declare_complex_constant_arrays();

	// Emit constexpr samplers here.
	for (auto &samp : constexpr_samplers_by_id)
	{
		auto &var = get<SPIRVariable>(samp.first);
		auto &type = get<SPIRType>(var.basetype);
		if (type.basetype == SPIRType::Sampler)
			add_resource_name(samp.first);

		SmallVector<string> args;
		auto &s = samp.second;

		if (s.coord != MSL_SAMPLER_COORD_NORMALIZED)
			args.push_back("coord::pixel");

		if (s.min_filter == s.mag_filter)
		{
			if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("filter::linear");
		}
		else
		{
			if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("min_filter::linear");
			if (s.mag_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("mag_filter::linear");
		}

		switch (s.mip_filter)
		{
		case MSL_SAMPLER_MIP_FILTER_NONE:
			// Default
			break;
		case MSL_SAMPLER_MIP_FILTER_NEAREST:
			args.push_back("mip_filter::nearest");
			break;
		case MSL_SAMPLER_MIP_FILTER_LINEAR:
			args.push_back("mip_filter::linear");
			break;
		default:
			SPIRV_CROSS_THROW("Invalid mip filter.");
		}

		if (s.s_address == s.t_address && s.s_address == s.r_address)
		{
			if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("", s.s_address));
		}
		else
		{
			if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("s_", s.s_address));
			if (s.t_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("t_", s.t_address));
			if (s.r_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("r_", s.r_address));
		}

		if (s.compare_enable)
		{
			switch (s.compare_func)
			{
			case MSL_SAMPLER_COMPARE_FUNC_ALWAYS:
				args.push_back("compare_func::always");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_NEVER:
				args.push_back("compare_func::never");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_EQUAL:
				args.push_back("compare_func::equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_NOT_EQUAL:
				args.push_back("compare_func::not_equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_LESS:
				args.push_back("compare_func::less");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_LESS_EQUAL:
				args.push_back("compare_func::less_equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_GREATER:
				args.push_back("compare_func::greater");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_GREATER_EQUAL:
				args.push_back("compare_func::greater_equal");
				break;
			default:
				SPIRV_CROSS_THROW("Invalid sampler compare function.");
			}
		}

		if (s.s_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER || s.t_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER ||
		    s.r_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER)
		{
			switch (s.border_color)
			{
			case MSL_SAMPLER_BORDER_COLOR_OPAQUE_BLACK:
				args.push_back("border_color::opaque_black");
				break;
			case MSL_SAMPLER_BORDER_COLOR_OPAQUE_WHITE:
				args.push_back("border_color::opaque_white");
				break;
			case MSL_SAMPLER_BORDER_COLOR_TRANSPARENT_BLACK:
				args.push_back("border_color::transparent_black");
				break;
			default:
				SPIRV_CROSS_THROW("Invalid sampler border color.");
			}
		}

		if (s.anisotropy_enable)
			args.push_back(join("max_anisotropy(", s.max_anisotropy, ")"));
		if (s.lod_clamp_enable)
		{
			args.push_back(join("lod_clamp(", convert_to_string(s.lod_clamp_min, current_locale_radix_character), ", ",
			                    convert_to_string(s.lod_clamp_max, current_locale_radix_character), ")"));
		}

		// If we would emit no arguments, then omit the parentheses entirely. Otherwise,
		// we'll wind up with a "most vexing parse" situation.
		if (args.empty())
			statement("constexpr sampler ",
			          type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first),
			          ";");
		else
			statement("constexpr sampler ",
			          type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first),
			          "(", merge(args), ");");
	}

1223
1224 // Emit dynamic buffers here.
1225 for (auto &dynamic_buffer : buffers_requiring_dynamic_offset)
1226 {
1227 if (!dynamic_buffer.second.second)
1228 {
1229 // Could happen if no buffer was used at requested binding point.
1230 continue;
1231 }
1232
1233 const auto &var = get<SPIRVariable>(dynamic_buffer.second.second);
1234 uint32_t var_id = var.self;
1235 const auto &type = get_variable_data_type(var);
1236 string name = to_name(var.self);
1237 uint32_t desc_set = get_decoration(var.self, DecorationDescriptorSet);
1238 uint32_t arg_id = argument_buffer_ids[desc_set];
1239 uint32_t base_index = dynamic_buffer.second.first;
1240
1241 if (!type.array.empty())
1242 {
1243 // This is complicated, because we need to support arrays of arrays.
1244 // And it's even worse if the outermost dimension is a runtime array, because now
1245 // all this complicated goop has to go into the shader itself. (FIXME)
1246 if (!type.array[type.array.size() - 1])
1247 SPIRV_CROSS_THROW("Runtime arrays with dynamic offsets are not supported yet.");
1248 else
1249 {
1250 is_using_builtin_array = true;
1251 statement(get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id), name,
1252 type_to_array_glsl(type), " =");
1253
1254 uint32_t dim = uint32_t(type.array.size());
1255 uint32_t j = 0;
1256 for (SmallVector<uint32_t> indices(type.array.size());
1257 indices[type.array.size() - 1] < to_array_size_literal(type); j++)
1258 {
1259 while (dim > 0)
1260 {
1261 begin_scope();
1262 --dim;
1263 }
1264
1265 string arrays;
1266 for (uint32_t i = uint32_t(type.array.size()); i; --i)
1267 arrays += join("[", indices[i - 1], "]");
1268 statement("(", get_argument_address_space(var), " ", type_to_glsl(type), "* ",
1269 to_restrict(var_id, false), ")((", get_argument_address_space(var), " char* ",
1270 to_restrict(var_id, false), ")", to_name(arg_id), ".", ensure_valid_name(name, "m"),
1271 arrays, " + ", to_name(dynamic_offsets_buffer_id), "[", base_index + j, "]),");
1272
1273 while (++indices[dim] >= to_array_size_literal(type, dim) && dim < type.array.size() - 1)
1274 {
1275 end_scope(",");
1276 indices[dim++] = 0;
1277 }
1278 }
1279 end_scope_decl();
1280 statement_no_indent("");
1281 is_using_builtin_array = false;
1282 }
1283 }
1284 else
1285 {
1286 statement(get_argument_address_space(var), " auto& ", to_restrict(var_id), name, " = *(",
1287 get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id, false), ")((",
1288 get_argument_address_space(var), " char* ", to_restrict(var_id, false), ")", to_name(arg_id), ".",
1289 ensure_valid_name(name, "m"), " + ", to_name(dynamic_offsets_buffer_id), "[", base_index, "]);");
1290 }
1291 }
1292
	// Emit buffer arrays here.
	for (uint32_t array_id : buffer_arrays)
	{
		const auto &var = get<SPIRVariable>(array_id);
		const auto &type = get_variable_data_type(var);
		const auto &buffer_type = get_variable_element_type(var);
		string name = to_name(array_id);
		statement(get_argument_address_space(var), " ", type_to_glsl(buffer_type), "* ", to_restrict(array_id), name,
		          "[] =");
		begin_scope();
		for (uint32_t i = 0; i < to_array_size_literal(type); ++i)
			statement(name, "_", i, ",");
		end_scope_decl();
		statement_no_indent("");
	}
	// For some reason, without this, we end up emitting the arrays twice.
	buffer_arrays.clear();

	// Emit disabled fragment outputs.
	std::sort(disabled_frag_outputs.begin(), disabled_frag_outputs.end());
	for (uint32_t var_id : disabled_frag_outputs)
	{
		auto &var = get<SPIRVariable>(var_id);
		add_local_variable_name(var_id);
		statement(variable_decl(var), ";");
		var.deferred_declaration = false;
	}
}

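// Compiles the loaded SPIR-V into Metal Shading Language source and returns it.
// Compilation may take several passes over the IR; code is re-emitted until no
// further recompilation is forced.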
string CompilerMSL::compile()
{
	replace_illegal_entry_point_names();
	ir.fixup_reserved_names();

	// Do not deal with GLES-isms like precision, older extensions and such.
	options.vulkan_semantics = true;
	options.es = false;
	options.version = 450;
	backend.null_pointer_literal = "nullptr";
	backend.float_literal_suffix = false;
	backend.uint32_t_literal_suffix = true;
	backend.int16_t_literal_suffix = "";
	backend.uint16_t_literal_suffix = "";
	backend.basic_int_type = "int";
	backend.basic_uint_type = "uint";
	backend.basic_int8_type = "char";
	backend.basic_uint8_type = "uchar";
	backend.basic_int16_type = "short";
	backend.basic_uint16_type = "ushort";
	backend.discard_literal = "discard_fragment()";
	backend.demote_literal = "discard_fragment()";
	backend.boolean_mix_function = "select";
	backend.swizzle_is_function = false;
	backend.shared_is_implied = false;
	backend.use_initializer_list = true;
	backend.use_typed_initializer_list = true;
	backend.native_row_major_matrix = false;
	backend.unsized_array_supported = false;
	backend.can_declare_arrays_inline = false;
	backend.allow_truncated_access_chain = true;
	backend.comparison_image_samples_scalar = true;
	backend.native_pointers = true;
	backend.nonuniform_qualifier = "";
	backend.support_small_type_sampling_result = true;
	backend.supports_empty_struct = true;

	// Allow Metal to use the array<T> template unless we force it off.
	backend.can_return_array = !msl_options.force_native_arrays;
	backend.array_is_value_type = !msl_options.force_native_arrays;
	// Arrays which are part of buffer objects are never considered to be native arrays.
	backend.buffer_offset_array_is_value_type = false;
	backend.support_pointer_to_pointer = true;

	capture_output_to_buffer = msl_options.capture_output_to_buffer;
	is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer;

	// Initialize array here rather than constructor, MSVC 2013 workaround.
	for (auto &id : next_metal_resource_ids)
		id = 0;

	fixup_type_alias();
	replace_illegal_names();
	sync_entry_point_aliases_and_names();

	build_function_control_flow_graphs_and_analyze();
	update_active_builtins();
	analyze_image_and_sampler_usage();
	analyze_sampled_image_usage();
	analyze_interlocked_resource_usage();
	preprocess_op_codes();
	build_implicit_builtins();

	fixup_image_load_store_access();

	set_enabled_interface_variables(get_active_interface_variables());
	if (msl_options.force_active_argument_buffer_resources)
		activate_argument_buffer_resources();

	if (swizzle_buffer_id)
		active_interface_variables.insert(swizzle_buffer_id);
	if (buffer_size_buffer_id)
		active_interface_variables.insert(buffer_size_buffer_id);
	if (view_mask_buffer_id)
		active_interface_variables.insert(view_mask_buffer_id);
	if (dynamic_offsets_buffer_id)
		active_interface_variables.insert(dynamic_offsets_buffer_id);
	if (builtin_layer_id)
		active_interface_variables.insert(builtin_layer_id);
	if (builtin_dispatch_base_id && !msl_options.supports_msl_version(1, 2))
		active_interface_variables.insert(builtin_dispatch_base_id);
	if (builtin_sample_mask_id)
		active_interface_variables.insert(builtin_sample_mask_id);

	// Create structs to hold input, output and uniform variables.
	// Do output first to ensure the "out" struct is declared at the top of the entry function.
	qual_pos_var_name = "";
	stage_out_var_id = add_interface_block(StorageClassOutput);
	patch_stage_out_var_id = add_interface_block(StorageClassOutput, true);
	stage_in_var_id = add_interface_block(StorageClassInput);
	if (get_execution_model() == ExecutionModelTessellationEvaluation)
		patch_stage_in_var_id = add_interface_block(StorageClassInput, true);

	if (get_execution_model() == ExecutionModelTessellationControl)
		stage_out_ptr_var_id = add_interface_block_pointer(stage_out_var_id, StorageClassOutput);
	if (is_tessellation_shader())
		stage_in_ptr_var_id = add_interface_block_pointer(stage_in_var_id, StorageClassInput);

	// Metal vertex functions that define no output must disable rasterization and return void.
	if (!stage_out_var_id)
		is_rasterization_disabled = true;

	// Convert the use of global variables to recursively-passed function parameters.
	localize_global_variables();
	extract_global_variables_from_functions();

	// Mark any non-stage-in structs to be tightly packed.
	mark_packable_structs();
	reorder_type_alias();

	// Add fixup hooks required by shader inputs and outputs. This needs to happen before
	// the loop, so the hooks aren't added multiple times.
	fix_up_shader_inputs_outputs();

	// If we are using argument buffers, we create argument buffer structures for them here.
	// These buffers will be used in the entry point, not the individual resources.
	if (msl_options.argument_buffers)
	{
		if (!msl_options.supports_msl_version(2, 0))
			SPIRV_CROSS_THROW("Argument buffers can only be used with MSL 2.0 and up.");
		analyze_argument_buffers();
	}

	uint32_t pass_count = 0;
	do
	{
		if (pass_count >= 3)
			SPIRV_CROSS_THROW("Over 3 compilation loops detected. Must be a bug!");

		reset();

		// Start bindings at zero.
		next_metal_resource_index_buffer = 0;
		next_metal_resource_index_texture = 0;
		next_metal_resource_index_sampler = 0;
		for (auto &id : next_metal_resource_ids)
			id = 0;

		// Move constructor for this type is broken on GCC 4.9 ...
		buffer.reset();

		emit_header();
		emit_custom_templates();
		emit_specialization_constants_and_structs();
		emit_resources();
		emit_custom_functions();
		emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());

		pass_count++;
	} while (is_forcing_recompilation());

	return buffer.str();
}

// Register the need to output any custom functions.
void CompilerMSL::preprocess_op_codes()
{
	OpCodePreprocessor preproc(*this);
	traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), preproc);

	suppress_missing_prototypes = preproc.suppress_missing_prototypes;

	if (preproc.uses_atomics)
	{
		add_header_line("#include <metal_atomic>");
		add_pragma_line("#pragma clang diagnostic ignored \"-Wunused-variable\"");
	}

	// Before MSL 2.1 (2.2 for textures), Metal vertex functions that write to
	// resources must disable rasterization and return void.
	if (preproc.uses_resource_write)
		is_rasterization_disabled = true;

	// Tessellation control shaders are run as compute functions in Metal, and so
	// must capture their output to a buffer.
	if (get_execution_model() == ExecutionModelTessellationControl ||
	    (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
	{
		is_rasterization_disabled = true;
		capture_output_to_buffer = true;
	}

	if (preproc.needs_subgroup_invocation_id)
		needs_subgroup_invocation_id = true;
	if (preproc.needs_subgroup_size)
		needs_subgroup_size = true;
	// build_implicit_builtins() hasn't run yet, and in fact, this needs to execute
	// before then so that gl_SampleID will get added; so we also need to check if
	// that function would add gl_FragCoord.
	if (preproc.needs_sample_id || msl_options.force_sample_rate_shading ||
	    (is_sample_rate() && (active_input_builtins.get(BuiltInFragCoord) ||
	                          (need_subpass_input && !msl_options.use_framebuffer_fetch_subpasses))))
		needs_sample_id = true;
}

1517 // Move the Private and Workgroup global variables to the entry function.
1518 // Non-constant variables cannot have global scope in Metal.
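// For example (illustrative): a module-scope "float4 tmp;" in the SPIR-V becomes a
// local declaration inside the MSL entry function, while LUT-style constant arrays
// are left out of this pass and emitted as constants instead.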
1519 void CompilerMSL::localize_global_variables()
1520 {
1521 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1522 auto iter = global_variables.begin();
1523 while (iter != global_variables.end())
1524 {
1525 uint32_t v_id = *iter;
1526 auto &var = get<SPIRVariable>(v_id);
1527 if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup)
1528 {
1529 if (!variable_is_lut(var))
1530 entry_func.add_local_variable(v_id);
1531 iter = global_variables.erase(iter);
1532 }
1533 else
1534 iter++;
1535 }
1536 }
1537
1538 // For any global variable accessed directly by a function,
1539 // extract that variable and add it as an argument to that function.
1540 void CompilerMSL::extract_global_variables_from_functions()
1541 {
1542 // Uniforms
1543 unordered_set<uint32_t> global_var_ids;
1544 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1545 if (var.storage == StorageClassInput || var.storage == StorageClassOutput ||
1546 var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
1547 var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer)
1548 {
1549 global_var_ids.insert(var.self);
1550 }
1551 });
1552
1553 // Local vars that are declared in the main function and accessed directly by a function
1554 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1555 for (auto &var : entry_func.local_variables)
1556 if (get<SPIRVariable>(var).storage != StorageClassFunction)
1557 global_var_ids.insert(var);
1558
1559 std::set<uint32_t> added_arg_ids;
1560 unordered_set<uint32_t> processed_func_ids;
1561 extract_global_variables_from_function(ir.default_entry_point, added_arg_ids, global_var_ids, processed_func_ids);
1562 }
1563
1564 // MSL does not support the use of global variables for shader input content.
1565 // For any global variable accessed directly by the specified function, extract that variable,
1566 // add it as an argument to that function, and add the arg to the added_arg_ids collection.
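// For example (a sketch, not verbatim output): a helper that reads a uniform buffer
// directly ends up with the buffer threaded through as a parameter, roughly
//   float4 get_color(constant UBO& ubo) { return ubo.color; }
// and every call site passes the buffer along.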
1567 void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::set<uint32_t> &added_arg_ids,
1568 unordered_set<uint32_t> &global_var_ids,
1569 unordered_set<uint32_t> &processed_func_ids)
1570 {
1571 // Avoid processing a function more than once
1572 if (processed_func_ids.find(func_id) != processed_func_ids.end())
1573 {
1574 // Return function global variables
1575 added_arg_ids = function_global_vars[func_id];
1576 return;
1577 }
1578
1579 processed_func_ids.insert(func_id);
1580
1581 auto &func = get<SPIRFunction>(func_id);
1582
1583 // Recursively establish global args added to functions on which we depend.
1584 for (auto block : func.blocks)
1585 {
1586 auto &b = get<SPIRBlock>(block);
1587 for (auto &i : b.ops)
1588 {
1589 auto ops = stream(i);
1590 auto op = static_cast<Op>(i.op);
1591
1592 switch (op)
1593 {
1594 case OpLoad:
1595 case OpInBoundsAccessChain:
1596 case OpAccessChain:
1597 case OpPtrAccessChain:
1598 case OpArrayLength:
1599 {
1600 uint32_t base_id = ops[2];
1601 if (global_var_ids.find(base_id) != global_var_ids.end())
1602 added_arg_ids.insert(base_id);
1603
1604 // Use Metal's native frame-buffer fetch API for subpass inputs.
1605 auto &type = get<SPIRType>(ops[0]);
1606 if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
1607 (!msl_options.use_framebuffer_fetch_subpasses))
1608 {
1609 // Implicitly reads gl_FragCoord.
1610 assert(builtin_frag_coord_id != 0);
1611 added_arg_ids.insert(builtin_frag_coord_id);
1612 if (msl_options.multiview)
1613 {
1614 // Implicitly reads gl_ViewIndex.
1615 assert(builtin_view_idx_id != 0);
1616 added_arg_ids.insert(builtin_view_idx_id);
1617 }
1618 else if (msl_options.arrayed_subpass_input)
1619 {
1620 // Implicitly reads gl_Layer.
1621 assert(builtin_layer_id != 0);
1622 added_arg_ids.insert(builtin_layer_id);
1623 }
1624 }
1625
1626 break;
1627 }
1628
1629 case OpFunctionCall:
1630 {
1631 // First see if any of the function call args are globals
1632 for (uint32_t arg_idx = 3; arg_idx < i.length; arg_idx++)
1633 {
1634 uint32_t arg_id = ops[arg_idx];
1635 if (global_var_ids.find(arg_id) != global_var_ids.end())
1636 added_arg_ids.insert(arg_id);
1637 }
1638
1639 // Then recurse into the function itself to extract globals used internally in the function
1640 uint32_t inner_func_id = ops[2];
1641 std::set<uint32_t> inner_func_args;
1642 extract_global_variables_from_function(inner_func_id, inner_func_args, global_var_ids,
1643 processed_func_ids);
1644 added_arg_ids.insert(inner_func_args.begin(), inner_func_args.end());
1645 break;
1646 }
1647
1648 case OpStore:
1649 {
1650 uint32_t base_id = ops[0];
1651 if (global_var_ids.find(base_id) != global_var_ids.end())
1652 added_arg_ids.insert(base_id);
1653
1654 uint32_t rvalue_id = ops[1];
1655 if (global_var_ids.find(rvalue_id) != global_var_ids.end())
1656 added_arg_ids.insert(rvalue_id);
1657
1658 break;
1659 }
1660
1661 case OpSelect:
1662 {
1663 uint32_t base_id = ops[3];
1664 if (global_var_ids.find(base_id) != global_var_ids.end())
1665 added_arg_ids.insert(base_id);
1666 base_id = ops[4];
1667 if (global_var_ids.find(base_id) != global_var_ids.end())
1668 added_arg_ids.insert(base_id);
1669 break;
1670 }
1671
1672 // Emulate texture2D atomic operations
1673 case OpImageTexelPointer:
1674 {
1675 // When using the pointer, we need to know which variable it is actually loaded from.
1676 uint32_t base_id = ops[2];
1677 auto *var = maybe_get_backing_variable(base_id);
1678 if (var && atomic_image_vars.count(var->self))
1679 {
1680 if (global_var_ids.find(base_id) != global_var_ids.end())
1681 added_arg_ids.insert(base_id);
1682 }
1683 break;
1684 }
1685
1686 case OpExtInst:
1687 {
1688 uint32_t extension_set = ops[2];
1689 if (get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
1690 {
1691 auto op_450 = static_cast<GLSLstd450>(ops[3]);
1692 switch (op_450)
1693 {
1694 case GLSLstd450InterpolateAtCentroid:
1695 case GLSLstd450InterpolateAtSample:
1696 case GLSLstd450InterpolateAtOffset:
1697 {
1698 // For these, we really need the stage-in block. It is theoretically possible to pass the
1699 // interpolant object, but a) doing so would require us to create an entirely new variable
1700 // with Interpolant type, and b) if we have a struct or array, handling all the members and
1701 // elements could get unwieldy fast.
1702 added_arg_ids.insert(stage_in_var_id);
1703 break;
1704 }
1705 default:
1706 break;
1707 }
1708 }
1709 break;
1710 }
1711
1712 case OpGroupNonUniformInverseBallot:
1713 {
1714 added_arg_ids.insert(builtin_subgroup_invocation_id_id);
1715 break;
1716 }
1717
1718 case OpGroupNonUniformBallotFindLSB:
1719 case OpGroupNonUniformBallotFindMSB:
1720 {
1721 added_arg_ids.insert(builtin_subgroup_size_id);
1722 break;
1723 }
1724
1725 case OpGroupNonUniformBallotBitCount:
1726 {
1727 auto operation = static_cast<GroupOperation>(ops[3]);
1728 switch (operation)
1729 {
1730 case GroupOperationReduce:
1731 added_arg_ids.insert(builtin_subgroup_size_id);
1732 break;
1733 case GroupOperationInclusiveScan:
1734 case GroupOperationExclusiveScan:
1735 added_arg_ids.insert(builtin_subgroup_invocation_id_id);
1736 break;
1737 default:
1738 break;
1739 }
1740 break;
1741 }
1742
1743 default:
1744 break;
1745 }
1746
1747 // TODO: Add all other operations which can affect memory.
1748 // We should consider a more unified system here to reduce boiler-plate.
1749 // This kind of analysis is done in several places ...
1750 }
1751 }
1752
1753 function_global_vars[func_id] = added_arg_ids;
1754
1755 // Add the global variables as arguments to the function
1756 if (func_id != ir.default_entry_point)
1757 {
1758 bool control_point_added_in = false;
1759 bool control_point_added_out = false;
1760 bool patch_added_in = false;
1761 bool patch_added_out = false;
1762
1763 for (uint32_t arg_id : added_arg_ids)
1764 {
1765 auto &var = get<SPIRVariable>(arg_id);
1766 uint32_t type_id = var.basetype;
1767 auto *p_type = &get<SPIRType>(type_id);
1768 BuiltIn bi_type = BuiltIn(get_decoration(arg_id, DecorationBuiltIn));
1769
1770 bool is_patch = has_decoration(arg_id, DecorationPatch) || is_patch_block(*p_type);
1771 bool is_block = has_decoration(p_type->self, DecorationBlock);
1772 bool is_control_point_storage =
1773 !is_patch &&
1774 ((is_tessellation_shader() && var.storage == StorageClassInput) ||
1775 (get_execution_model() == ExecutionModelTessellationControl && var.storage == StorageClassOutput));
1776 bool is_patch_block_storage = is_patch && is_block && var.storage == StorageClassOutput;
1777 bool is_builtin = is_builtin_variable(var);
1778 bool variable_is_stage_io =
1779 !is_builtin || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
1780 bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
1781 p_type->basetype == SPIRType::Struct;
1782 bool is_redirected_to_global_stage_io = (is_control_point_storage || is_patch_block_storage) &&
1783 variable_is_stage_io;
1784
1785 // If output is masked it is not considered part of the global stage IO interface.
1786 if (is_redirected_to_global_stage_io && var.storage == StorageClassOutput)
1787 is_redirected_to_global_stage_io = !is_stage_output_variable_masked(var);
1788
1789 if (is_redirected_to_global_stage_io)
1790 {
1791 // Tessellation control shaders see inputs and per-vertex outputs as arrays.
1792 // Similarly, tessellation evaluation shaders see per-vertex inputs as arrays.
1793 // We collected them into a structure; we must pass the array of this
1794 // structure to the function.
1795 std::string name;
1796 if (is_patch)
1797 name = var.storage == StorageClassInput ? patch_stage_in_var_name : patch_stage_out_var_name;
1798 else
1799 name = var.storage == StorageClassInput ? "gl_in" : "gl_out";
1800
1801 if (var.storage == StorageClassOutput && has_decoration(p_type->self, DecorationBlock))
1802 {
1803 // If we're redirecting a block, we might still need to access the original block
1804 // variable if we're masking some members.
1805 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(p_type->member_types.size()); mbr_idx++)
1806 {
1807 if (is_stage_output_block_member_masked(var, mbr_idx, true))
1808 {
1809 func.add_parameter(var.basetype, var.self, true);
1810 break;
1811 }
1812 }
1813 }
1814
1815 // Tessellation control shaders see inputs and per-vertex outputs as arrays.
1816 // Similarly, tessellation evaluation shaders see per-vertex inputs as arrays.
1817 // We collected them into a structure; we must pass the array of this
1818 // structure to the function.
1819 if (var.storage == StorageClassInput)
1820 {
1821 auto &added_in = is_patch ? patch_added_in : control_point_added_in;
1822 if (added_in)
1823 continue;
1824 arg_id = is_patch ? patch_stage_in_var_id : stage_in_ptr_var_id;
1825 added_in = true;
1826 }
1827 else if (var.storage == StorageClassOutput)
1828 {
1829 auto &added_out = is_patch ? patch_added_out : control_point_added_out;
1830 if (added_out)
1831 continue;
1832 arg_id = is_patch ? patch_stage_out_var_id : stage_out_ptr_var_id;
1833 added_out = true;
1834 }
1835
1836 type_id = get<SPIRVariable>(arg_id).basetype;
1837 uint32_t next_id = ir.increase_bound_by(1);
1838 func.add_parameter(type_id, next_id, true);
1839 set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);
1840
1841 set_name(next_id, name);
1842 }
1843 else if (is_builtin && has_decoration(p_type->self, DecorationBlock))
1844 {
1845 // Get the pointee type
1846 type_id = get_pointee_type_id(type_id);
1847 p_type = &get<SPIRType>(type_id);
1848
1849 uint32_t mbr_idx = 0;
1850 for (auto &mbr_type_id : p_type->member_types)
1851 {
1852 BuiltIn builtin = BuiltInMax;
1853 is_builtin = is_member_builtin(*p_type, mbr_idx, &builtin);
1854 if (is_builtin && has_active_builtin(builtin, var.storage))
1855 {
1856 // Add an arg variable with the same type and decorations as the member
1857 uint32_t next_ids = ir.increase_bound_by(2);
1858 uint32_t ptr_type_id = next_ids + 0;
1859 uint32_t var_id = next_ids + 1;
1860
1861 // Make sure we have an actual pointer type,
1862 // so that we will get the appropriate address space when declaring these builtins.
1863 auto &ptr = set<SPIRType>(ptr_type_id, get<SPIRType>(mbr_type_id));
1864 ptr.self = mbr_type_id;
1865 ptr.storage = var.storage;
1866 ptr.pointer = true;
1867 ptr.pointer_depth++;
1868 ptr.parent_type = mbr_type_id;
1869
1870 func.add_parameter(mbr_type_id, var_id, true);
1871 set<SPIRVariable>(var_id, ptr_type_id, StorageClassFunction);
1872 ir.meta[var_id].decoration = ir.meta[type_id].members[mbr_idx];
1873 }
1874 mbr_idx++;
1875 }
1876 }
1877 else
1878 {
1879 uint32_t next_id = ir.increase_bound_by(1);
1880 func.add_parameter(type_id, next_id, true);
1881 set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);
1882
1883 // Ensure the existing variable has a valid name and the new variable has all the same meta info
1884 set_name(arg_id, ensure_valid_name(to_name(arg_id), "v"));
1885 ir.meta[next_id] = ir.meta[arg_id];
1886 }
1887 }
1888 }
1889 }
1890
1891 // For all variables that are some form of non-input-output interface block, mark that all the structs
1892 // that are recursively contained within the type referenced by that variable should be packed tightly.
1893 void CompilerMSL::mark_packable_structs()
1894 {
1895 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1896 if (var.storage != StorageClassFunction && !is_hidden_variable(var))
1897 {
1898 auto &type = this->get<SPIRType>(var.basetype);
1899 if (type.pointer &&
1900 (type.storage == StorageClassUniform || type.storage == StorageClassUniformConstant ||
1901 type.storage == StorageClassPushConstant || type.storage == StorageClassStorageBuffer) &&
1902 (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
1903 mark_as_packable(type);
1904 }
1905 });
1906 }
1907
1908 // If the specified type is a struct, it and any nested structs
1909 // are marked as packable with the SPIRVCrossDecorationBufferBlockRepacked decoration.
1910 void CompilerMSL::mark_as_packable(SPIRType &type)
1911 {
1912 // If this is not the base type (e.g. it's a pointer or array), tunnel down
1913 if (type.parent_type)
1914 {
1915 mark_as_packable(get<SPIRType>(type.parent_type));
1916 return;
1917 }
1918
1919 if (type.basetype == SPIRType::Struct)
1920 {
1921 set_extended_decoration(type.self, SPIRVCrossDecorationBufferBlockRepacked);
1922
1923 // Recurse
1924 uint32_t mbr_cnt = uint32_t(type.member_types.size());
1925 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
1926 {
1927 uint32_t mbr_type_id = type.member_types[mbr_idx];
1928 auto &mbr_type = get<SPIRType>(mbr_type_id);
1929 mark_as_packable(mbr_type);
1930 if (mbr_type.type_alias)
1931 {
1932 auto &mbr_type_alias = get<SPIRType>(mbr_type.type_alias);
1933 mark_as_packable(mbr_type_alias);
1934 }
1935 }
1936 }
1937 }
1938
1939 // If a shader input exists at the location, it is marked as being used by this shader
1940 void CompilerMSL::mark_location_as_used_by_shader(uint32_t location, const SPIRType &type,
1941 StorageClass storage, bool fallback)
1942 {
1943 if (storage != StorageClassInput)
1944 return;
1945
1946 uint32_t count = type_to_location_count(type);
1947 for (uint32_t i = 0; i < count; i++)
1948 {
1949 location_inputs_in_use.insert(location + i);
1950 if (fallback)
1951 location_inputs_in_use_fallback.insert(location + i);
1952 }
1953 }
1954
1955 uint32_t CompilerMSL::get_target_components_for_fragment_location(uint32_t location) const
1956 {
1957 auto itr = fragment_output_components.find(location);
1958 if (itr == end(fragment_output_components))
1959 return 4;
1960 else
1961 return itr->second;
1962 }
1963
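// Builds a copy of type_id widened to 'components' vector components (optionally with
// a new base type), re-wrapping any array and pointer layers of the original type.
// Used, for example, to pad a float3 fragment output up to float4.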
1964 uint32_t CompilerMSL::build_extended_vector_type(uint32_t type_id, uint32_t components, SPIRType::BaseType basetype)
1965 {
1966 uint32_t new_type_id = ir.increase_bound_by(1);
1967 auto &old_type = get<SPIRType>(type_id);
1968 auto *type = &set<SPIRType>(new_type_id, old_type);
1969 type->vecsize = components;
1970 if (basetype != SPIRType::Unknown)
1971 type->basetype = basetype;
1972 type->self = new_type_id;
1973 type->parent_type = type_id;
1974 type->array.clear();
1975 type->array_size_literal.clear();
1976 type->pointer = false;
1977
1978 if (is_array(old_type))
1979 {
1980 uint32_t array_type_id = ir.increase_bound_by(1);
1981 type = &set<SPIRType>(array_type_id, *type);
1982 type->parent_type = new_type_id;
1983 type->array = old_type.array;
1984 type->array_size_literal = old_type.array_size_literal;
1985 new_type_id = array_type_id;
1986 }
1987
1988 if (old_type.pointer)
1989 {
1990 uint32_t ptr_type_id = ir.increase_bound_by(1);
1991 type = &set<SPIRType>(ptr_type_id, *type);
1992 type->self = new_type_id;
1993 type->parent_type = new_type_id;
1994 type->storage = old_type.storage;
1995 type->pointer = true;
1996 type->pointer_depth++;
1997 new_type_id = ptr_type_id;
1998 }
1999
2000 return new_type_id;
2001 }
2002
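// Builds a Metal pull-model interpolant type wrapping type_id, which in the emitted
// MSL reads roughly as interpolant<float4, interpolation::perspective> (illustrative).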
2003 uint32_t CompilerMSL::build_msl_interpolant_type(uint32_t type_id, bool is_noperspective)
2004 {
2005 uint32_t new_type_id = ir.increase_bound_by(1);
2006 SPIRType &type = set<SPIRType>(new_type_id, get<SPIRType>(type_id));
2007 type.basetype = SPIRType::Interpolant;
2008 type.parent_type = type_id;
2009 // In Metal, the pull-model interpolant type encodes perspective-vs-no-perspective in the type itself.
2010 // Add this decoration so we know which argument to pass to the template.
2011 if (is_noperspective)
2012 set_decoration(new_type_id, DecorationNoPerspective);
2013 return new_type_id;
2014 }
2015
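// Handles variables that use the Component decoration to pack several variables into
// a single location. Such variables alias a shared IO-block member (m_location_N) and
// are copied in/out with swizzles via fixup hooks. Returns true if the variable was
// fully handled here.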
2016 bool CompilerMSL::add_component_variable_to_interface_block(spv::StorageClass storage, const std::string &ib_var_ref,
2017 SPIRVariable &var,
2018 const SPIRType &type,
2019 InterfaceBlockMeta &meta)
2020 {
2021 // Deal with Component decorations.
2022 const InterfaceBlockMeta::LocationMeta *location_meta = nullptr;
2023 uint32_t location = ~0u;
2024 if (has_decoration(var.self, DecorationLocation))
2025 {
2026 location = get_decoration(var.self, DecorationLocation);
2027 auto location_meta_itr = meta.location_meta.find(location);
2028 if (location_meta_itr != end(meta.location_meta))
2029 location_meta = &location_meta_itr->second;
2030 }
2031
2032 // Check if we need to pad fragment output to match a certain number of components.
2033 if (location_meta)
2034 {
2035 bool pad_fragment_output = has_decoration(var.self, DecorationLocation) &&
2036 msl_options.pad_fragment_output_components &&
2037 get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput;
2038
2039 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2040 uint32_t start_component = get_decoration(var.self, DecorationComponent);
2041 uint32_t type_components = type.vecsize;
2042 uint32_t num_components = location_meta->num_components;
2043
2044 if (pad_fragment_output)
2045 {
2046 uint32_t locn = get_decoration(var.self, DecorationLocation);
2047 num_components = std::max(num_components, get_target_components_for_fragment_location(locn));
2048 }
2049
2050 // We have already declared an IO block member as m_location_N.
2051 // Just emit an early-declared variable and fixup as needed.
2052 // Arrays need to be unrolled here since each location might need a different number of components.
2053 entry_func.add_local_variable(var.self);
2054 vars_needing_early_declaration.push_back(var.self);
2055
2056 if (var.storage == StorageClassInput)
2057 {
2058 entry_func.fixup_hooks_in.push_back([=, &type, &var]() {
2059 if (!type.array.empty())
2060 {
2061 uint32_t array_size = to_array_size_literal(type);
2062 for (uint32_t loc_off = 0; loc_off < array_size; loc_off++)
2063 {
2064 statement(to_name(var.self), "[", loc_off, "]", " = ", ib_var_ref,
2065 ".m_location_", location + loc_off,
2066 vector_swizzle(type_components, start_component), ";");
2067 }
2068 }
2069 else
2070 {
2071 statement(to_name(var.self), " = ", ib_var_ref, ".m_location_", location,
2072 vector_swizzle(type_components, start_component), ";");
2073 }
2074 });
2075 }
2076 else
2077 {
2078 entry_func.fixup_hooks_out.push_back([=, &type, &var]() {
2079 if (!type.array.empty())
2080 {
2081 uint32_t array_size = to_array_size_literal(type);
2082 for (uint32_t loc_off = 0; loc_off < array_size; loc_off++)
2083 {
2084 statement(ib_var_ref, ".m_location_", location + loc_off,
2085 vector_swizzle(type_components, start_component), " = ",
2086 to_name(var.self), "[", loc_off, "];");
2087 }
2088 }
2089 else
2090 {
2091 statement(ib_var_ref, ".m_location_", location,
2092 vector_swizzle(type_components, start_component), " = ", to_name(var.self), ";");
2093 }
2094 });
2095 }
2096 return true;
2097 }
2098 else
2099 return false;
2100 }
2101
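// Adds a plain (scalar/vector) variable to the stage IO interface block, handling
// fragment-output padding, pull-model interpolation wrappers, and builtin/location
// decorations.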
2102 void CompilerMSL::add_plain_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2103 SPIRType &ib_type, SPIRVariable &var, InterfaceBlockMeta &meta)
2104 {
2105 bool is_builtin = is_builtin_variable(var);
2106 BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
2107 bool is_flat = has_decoration(var.self, DecorationFlat);
2108 bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
2109 bool is_centroid = has_decoration(var.self, DecorationCentroid);
2110 bool is_sample = has_decoration(var.self, DecorationSample);
2111
2112 // Add a reference to the variable type to the interface struct.
2113 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2114 uint32_t type_id = ensure_correct_builtin_type(var.basetype, builtin);
2115 var.basetype = type_id;
2116
2117 type_id = get_pointee_type_id(var.basetype);
2118 if (meta.strip_array && is_array(get<SPIRType>(type_id)))
2119 type_id = get<SPIRType>(type_id).parent_type;
2120 auto &type = get<SPIRType>(type_id);
2121 uint32_t target_components = 0;
2122 uint32_t type_components = type.vecsize;
2123
2124 bool padded_output = false;
2125 bool padded_input = false;
2126 uint32_t start_component = 0;
2127
2128 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2129
2130 if (add_component_variable_to_interface_block(storage, ib_var_ref, var, type, meta))
2131 return;
2132
2133 bool pad_fragment_output = has_decoration(var.self, DecorationLocation) &&
2134 msl_options.pad_fragment_output_components &&
2135 get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput;
2136
2137 if (pad_fragment_output)
2138 {
2139 uint32_t locn = get_decoration(var.self, DecorationLocation);
2140 target_components = get_target_components_for_fragment_location(locn);
2141 if (type_components < target_components)
2142 {
2143 // Make a new type here.
2144 type_id = build_extended_vector_type(type_id, target_components);
2145 padded_output = true;
2146 }
2147 }
2148
2149 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2150 ib_type.member_types.push_back(build_msl_interpolant_type(type_id, is_noperspective));
2151 else
2152 ib_type.member_types.push_back(type_id);
2153
2154 // Give the member a name
2155 string mbr_name = ensure_valid_name(to_expression(var.self), "m");
2156 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2157
2158 // Update the original variable reference to include the structure reference
2159 string qual_var_name = ib_var_ref + "." + mbr_name;
2160 // If using pull-model interpolation, need to add a call to the correct interpolation method.
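// (Illustrative: the qualified access then reads e.g. "in.m_foo.interpolate_at_sample(gl_SampleID)".)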
2161 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2162 {
2163 if (is_centroid)
2164 qual_var_name += ".interpolate_at_centroid()";
2165 else if (is_sample)
2166 qual_var_name += join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
2167 else
2168 qual_var_name += ".interpolate_at_center()";
2169 }
2170
2171 if (padded_output || padded_input)
2172 {
2173 entry_func.add_local_variable(var.self);
2174 vars_needing_early_declaration.push_back(var.self);
2175
2176 if (padded_output)
2177 {
2178 entry_func.fixup_hooks_out.push_back([=, &var]() {
2179 statement(qual_var_name, vector_swizzle(type_components, start_component), " = ", to_name(var.self),
2180 ";");
2181 });
2182 }
2183 else
2184 {
2185 entry_func.fixup_hooks_in.push_back([=, &var]() {
2186 statement(to_name(var.self), " = ", qual_var_name, vector_swizzle(type_components, start_component),
2187 ";");
2188 });
2189 }
2190 }
2191 else if (!meta.strip_array)
2192 ir.meta[var.self].decoration.qualified_alias = qual_var_name;
2193
2194 if (var.storage == StorageClassOutput && var.initializer != ID(0))
2195 {
2196 if (padded_output || padded_input)
2197 {
2198 entry_func.fixup_hooks_in.push_back(
2199 [=, &var]() { statement(to_name(var.self), " = ", to_expression(var.initializer), ";"); });
2200 }
2201 else
2202 {
2203 if (meta.strip_array)
2204 {
2205 entry_func.fixup_hooks_in.push_back([=, &var]() {
2206 uint32_t index = get_extended_decoration(var.self, SPIRVCrossDecorationInterfaceMemberIndex);
2207 auto invocation = to_tesc_invocation_id();
2208 statement(to_expression(stage_out_ptr_var_id), "[",
2209 invocation, "].",
2210 to_member_name(ib_type, index), " = ", to_expression(var.initializer), "[",
2211 invocation, "];");
2212 });
2213 }
2214 else
2215 {
2216 entry_func.fixup_hooks_in.push_back([=, &var]() {
2217 statement(qual_var_name, " = ", to_expression(var.initializer), ";");
2218 });
2219 }
2220 }
2221 }
2222
2223 // Copy the variable location from the original variable to the member
2224 if (get_decoration_bitset(var.self).get(DecorationLocation))
2225 {
2226 uint32_t locn = get_decoration(var.self, DecorationLocation);
2227 if (storage == StorageClassInput)
2228 {
2229 type_id = ensure_correct_input_type(var.basetype, locn, 0, meta.strip_array);
2230 var.basetype = type_id;
2231
2232 type_id = get_pointee_type_id(type_id);
2233 if (meta.strip_array && is_array(get<SPIRType>(type_id)))
2234 type_id = get<SPIRType>(type_id).parent_type;
2235 if (pull_model_inputs.count(var.self))
2236 ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(type_id, is_noperspective);
2237 else
2238 ib_type.member_types[ib_mbr_idx] = type_id;
2239 }
2240 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2241 mark_location_as_used_by_shader(locn, get<SPIRType>(type_id), storage);
2242 }
2243 else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
2244 {
2245 uint32_t locn = inputs_by_builtin[builtin].location;
2246 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2247 mark_location_as_used_by_shader(locn, type, storage);
2248 }
2249
2250 if (get_decoration_bitset(var.self).get(DecorationComponent))
2251 {
2252 uint32_t component = get_decoration(var.self, DecorationComponent);
2253 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, component);
2254 }
2255
2256 if (get_decoration_bitset(var.self).get(DecorationIndex))
2257 {
2258 uint32_t index = get_decoration(var.self, DecorationIndex);
2259 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
2260 }
2261
2262 // Mark the member as builtin if needed
2263 if (is_builtin)
2264 {
2265 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2266 if (builtin == BuiltInPosition && storage == StorageClassOutput)
2267 qual_pos_var_name = qual_var_name;
2268 }
2269
2270 // Copy interpolation decorations if needed
2271 if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
2272 {
2273 if (is_flat)
2274 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2275 if (is_noperspective)
2276 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2277 if (is_centroid)
2278 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2279 if (is_sample)
2280 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2281 }
2282
2283 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2284 }
2285
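// Adds a matrix or array variable to the stage IO interface block by flattening it
// into one block member per column/element, with fixup hooks that copy to and from a
// locally declared composite.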
2286 void CompilerMSL::add_composite_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2287 SPIRType &ib_type, SPIRVariable &var,
2288 InterfaceBlockMeta &meta)
2289 {
2290 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2291 auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2292 uint32_t elem_cnt = 0;
2293
2294 if (add_component_variable_to_interface_block(storage, ib_var_ref, var, var_type, meta))
2295 return;
2296
2297 if (is_matrix(var_type))
2298 {
2299 if (is_array(var_type))
2300 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
2301
2302 elem_cnt = var_type.columns;
2303 }
2304 else if (is_array(var_type))
2305 {
2306 if (var_type.array.size() != 1)
2307 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
2308
2309 elem_cnt = to_array_size_literal(var_type);
2310 }
2311
2312 bool is_builtin = is_builtin_variable(var);
2313 BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
2314 bool is_flat = has_decoration(var.self, DecorationFlat);
2315 bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
2316 bool is_centroid = has_decoration(var.self, DecorationCentroid);
2317 bool is_sample = has_decoration(var.self, DecorationSample);
2318
2319 auto *usable_type = &var_type;
2320 if (usable_type->pointer)
2321 usable_type = &get<SPIRType>(usable_type->parent_type);
2322 while (is_array(*usable_type) || is_matrix(*usable_type))
2323 usable_type = &get<SPIRType>(usable_type->parent_type);
2324
2325 // If a builtin, force it to have the proper name.
2326 if (is_builtin)
2327 set_name(var.self, builtin_to_glsl(builtin, StorageClassFunction));
2328
2329 bool flatten_from_ib_var = false;
2330 string flatten_from_ib_mbr_name;
2331
2332 if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
2333 {
2334 // Also declare [[clip_distance]] attribute here.
2335 uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
2336 ib_type.member_types.push_back(get_variable_data_type_id(var));
2337 set_member_decoration(ib_type.self, clip_array_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
2338
2339 flatten_from_ib_mbr_name = builtin_to_glsl(BuiltInClipDistance, StorageClassOutput);
2340 set_member_name(ib_type.self, clip_array_mbr_idx, flatten_from_ib_mbr_name);
2341
2342 // When we flatten, we flatten directly from the "out" struct,
2343 // not from a function variable.
2344 flatten_from_ib_var = true;
2345
2346 if (!msl_options.enable_clip_distance_user_varying)
2347 return;
2348 }
2349 else if (!meta.strip_array)
2350 {
2351 // Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
2352 entry_func.add_local_variable(var.self);
2353 // We need to declare the variable early and at entry-point scope.
2354 vars_needing_early_declaration.push_back(var.self);
2355 }
2356
2357 for (uint32_t i = 0; i < elem_cnt; i++)
2358 {
2359 // Add a reference to the variable type to the interface struct.
2360 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2361
2362 uint32_t target_components = 0;
2363 bool padded_output = false;
2364 uint32_t type_id = usable_type->self;
2365
2366 // Check if we need to pad fragment output to match a certain number of components.
2367 if (get_decoration_bitset(var.self).get(DecorationLocation) && msl_options.pad_fragment_output_components &&
2368 get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput)
2369 {
2370 uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
2371 target_components = get_target_components_for_fragment_location(locn);
2372 if (usable_type->vecsize < target_components)
2373 {
2374 // Make a new type here.
2375 type_id = build_extended_vector_type(usable_type->self, target_components);
2376 padded_output = true;
2377 }
2378 }
2379
2380 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2381 ib_type.member_types.push_back(build_msl_interpolant_type(get_pointee_type_id(type_id), is_noperspective));
2382 else
2383 ib_type.member_types.push_back(get_pointee_type_id(type_id));
2384
2385 // Give the member a name
2386 string mbr_name = ensure_valid_name(join(to_expression(var.self), "_", i), "m");
2387 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2388
2389 // There is no qualified alias since we need to flatten the internal array on return.
2390 if (get_decoration_bitset(var.self).get(DecorationLocation))
2391 {
2392 uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
2393 uint32_t comp = get_decoration(var.self, DecorationComponent);
2394 if (storage == StorageClassInput)
2395 {
2396 var.basetype = ensure_correct_input_type(var.basetype, locn, 0, meta.strip_array);
2397 uint32_t mbr_type_id = ensure_correct_input_type(usable_type->self, locn, 0, meta.strip_array);
2398 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2399 ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective);
2400 else
2401 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2402 }
2403 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2404 if (comp)
2405 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp);
2406 mark_location_as_used_by_shader(locn, *usable_type, storage);
2407 }
2408 else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
2409 {
2410 uint32_t locn = inputs_by_builtin[builtin].location + i;
2411 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2412 mark_location_as_used_by_shader(locn, *usable_type, storage);
2413 }
2414 else if (is_builtin && (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance))
2415 {
2416 // Declare the Clip/CullDistance as [[user(clip/cullN)]].
2417 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2418 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, i);
2419 }
2420
2421 if (get_decoration_bitset(var.self).get(DecorationIndex))
2422 {
2423 uint32_t index = get_decoration(var.self, DecorationIndex);
2424 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
2425 }
2426
2427 if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
2428 {
2429 // Copy interpolation decorations if needed
2430 if (is_flat)
2431 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2432 if (is_noperspective)
2433 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2434 if (is_centroid)
2435 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2436 if (is_sample)
2437 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2438 }
2439
2440 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2441
2442 // Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
2443 if (!meta.strip_array)
2444 {
2445 switch (storage)
2446 {
2447 case StorageClassInput:
2448 entry_func.fixup_hooks_in.push_back([=, &var]() {
2449 if (pull_model_inputs.count(var.self))
2450 {
2451 string lerp_call;
2452 if (is_centroid)
2453 lerp_call = ".interpolate_at_centroid()";
2454 else if (is_sample)
2455 lerp_call = join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
2456 else
2457 lerp_call = ".interpolate_at_center()";
2458 statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, lerp_call, ";");
2459 }
2460 else
2461 {
2462 statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, ";");
2463 }
2464 });
2465 break;
2466
2467 case StorageClassOutput:
2468 entry_func.fixup_hooks_out.push_back([=, &var]() {
2469 if (padded_output)
2470 {
2471 auto &padded_type = this->get<SPIRType>(type_id);
2472 statement(
2473 ib_var_ref, ".", mbr_name, " = ",
2474 remap_swizzle(padded_type, usable_type->vecsize, join(to_name(var.self), "[", i, "]")),
2475 ";");
2476 }
2477 else if (flatten_from_ib_var)
2478 statement(ib_var_ref, ".", mbr_name, " = ", ib_var_ref, ".", flatten_from_ib_mbr_name, "[", i,
2479 "];");
2480 else
2481 statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), "[", i, "];");
2482 });
2483 break;
2484
2485 default:
2486 break;
2487 }
2488 }
2489 }
2490 }
2491
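// Same as above, but for a struct member that is itself a matrix or array: each
// column/element becomes its own IO-block member.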
2492 void CompilerMSL::add_composite_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2493 SPIRType &ib_type, SPIRVariable &var,
2494 uint32_t mbr_idx, InterfaceBlockMeta &meta)
2495 {
2496 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2497 auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2498
2499 BuiltIn builtin = BuiltInMax;
2500 bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
2501 bool is_flat =
2502 has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
2503 bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
2504 has_decoration(var.self, DecorationNoPerspective);
2505 bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
2506 has_decoration(var.self, DecorationCentroid);
2507 bool is_sample =
2508 has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);
2509
2510 uint32_t mbr_type_id = var_type.member_types[mbr_idx];
2511 auto &mbr_type = get<SPIRType>(mbr_type_id);
2512 uint32_t elem_cnt = 0;
2513
2514 if (is_matrix(mbr_type))
2515 {
2516 if (is_array(mbr_type))
2517 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
2518
2519 elem_cnt = mbr_type.columns;
2520 }
2521 else if (is_array(mbr_type))
2522 {
2523 if (mbr_type.array.size() != 1)
2524 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
2525
2526 elem_cnt = to_array_size_literal(mbr_type);
2527 }
2528
2529 auto *usable_type = &mbr_type;
2530 if (usable_type->pointer)
2531 usable_type = &get<SPIRType>(usable_type->parent_type);
2532 while (is_array(*usable_type) || is_matrix(*usable_type))
2533 usable_type = &get<SPIRType>(usable_type->parent_type);
2534
2535 bool flatten_from_ib_var = false;
2536 string flatten_from_ib_mbr_name;
2537
2538 if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
2539 {
2540 // Also declare [[clip_distance]] attribute here.
2541 uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
2542 ib_type.member_types.push_back(mbr_type_id);
2543 set_member_decoration(ib_type.self, clip_array_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
2544
2545 flatten_from_ib_mbr_name = builtin_to_glsl(BuiltInClipDistance, StorageClassOutput);
2546 set_member_name(ib_type.self, clip_array_mbr_idx, flatten_from_ib_mbr_name);
2547
2548 // When we flatten, we flatten directly from the "out" struct,
2549 // not from a function variable.
2550 flatten_from_ib_var = true;
2551
2552 if (!msl_options.enable_clip_distance_user_varying)
2553 return;
2554 }
2555
2556 for (uint32_t i = 0; i < elem_cnt; i++)
2557 {
2558 // Add a reference to the variable type to the interface struct.
2559 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2560 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2561 ib_type.member_types.push_back(build_msl_interpolant_type(usable_type->self, is_noperspective));
2562 else
2563 ib_type.member_types.push_back(usable_type->self);
2564
2565 // Give the member a name
2566 string mbr_name = ensure_valid_name(join(to_qualified_member_name(var_type, mbr_idx), "_", i), "m");
2567 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2568
2569 if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
2570 {
2571 uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation) + i;
2572 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2573 mark_location_as_used_by_shader(locn, *usable_type, storage);
2574 }
2575 else if (has_decoration(var.self, DecorationLocation))
2576 {
2577 uint32_t locn = get_accumulated_member_location(var, mbr_idx, meta.strip_array) + i;
2578 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2579 mark_location_as_used_by_shader(locn, *usable_type, storage);
2580 }
2581 else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
2582 {
2583 uint32_t locn = inputs_by_builtin[builtin].location + i;
2584 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2585 mark_location_as_used_by_shader(locn, *usable_type, storage);
2586 }
2587 else if (is_builtin && (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance))
2588 {
2589 // Declare the Clip/CullDistance as [[user(clip/cullN)]].
2590 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2591 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, i);
2592 }
2593
2594 if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
2595 SPIRV_CROSS_THROW("DecorationComponent on matrices and arrays make little sense.");
2596
2597 if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
2598 {
2599 // Copy interpolation decorations if needed
2600 if (is_flat)
2601 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2602 if (is_noperspective)
2603 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2604 if (is_centroid)
2605 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2606 if (is_sample)
2607 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2608 }
2609
2610 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2611 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);
2612
2613 // Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
2614 if (!meta.strip_array && meta.allow_local_declaration)
2615 {
2616 switch (storage)
2617 {
2618 case StorageClassInput:
2619 entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
2620 if (pull_model_inputs.count(var.self))
2621 {
2622 string lerp_call;
2623 if (is_centroid)
2624 lerp_call = ".interpolate_at_centroid()";
2625 else if (is_sample)
2626 lerp_call = join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
2627 else
2628 lerp_call = ".interpolate_at_center()";
2629 statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), "[", i, "] = ", ib_var_ref,
2630 ".", mbr_name, lerp_call, ";");
2631 }
2632 else
2633 {
2634 statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), "[", i, "] = ", ib_var_ref,
2635 ".", mbr_name, ";");
2636 }
2637 });
2638 break;
2639
2640 case StorageClassOutput:
2641 entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
2642 if (flatten_from_ib_var)
2643 {
2644 statement(ib_var_ref, ".", mbr_name, " = ", ib_var_ref, ".", flatten_from_ib_mbr_name, "[", i,
2645 "];");
2646 }
2647 else
2648 {
2649 statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), ".",
2650 to_member_name(var_type, mbr_idx), "[", i, "];");
2651 }
2652 });
2653 break;
2654
2655 default:
2656 break;
2657 }
2658 }
2659 }
2660 }
2661
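// Adds a plain (scalar/vector) struct member to the stage IO interface block, either
// redirecting accesses through the block member or flattening via fixup hooks.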
2662 void CompilerMSL::add_plain_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2663 SPIRType &ib_type, SPIRVariable &var, uint32_t mbr_idx,
2664 InterfaceBlockMeta &meta)
2665 {
2666 auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2667 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2668
2669 BuiltIn builtin = BuiltInMax;
2670 bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
2671 bool is_flat =
2672 has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
2673 bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
2674 has_decoration(var.self, DecorationNoPerspective);
2675 bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
2676 has_decoration(var.self, DecorationCentroid);
2677 bool is_sample =
2678 has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);
2679
2680 // Add a reference to the member to the interface struct.
2681 uint32_t mbr_type_id = var_type.member_types[mbr_idx];
2682 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2683 mbr_type_id = ensure_correct_builtin_type(mbr_type_id, builtin);
2684 var_type.member_types[mbr_idx] = mbr_type_id;
2685 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2686 ib_type.member_types.push_back(build_msl_interpolant_type(mbr_type_id, is_noperspective));
2687 else
2688 ib_type.member_types.push_back(mbr_type_id);
2689
2690 // Give the member a name
2691 string mbr_name = ensure_valid_name(to_qualified_member_name(var_type, mbr_idx), "m");
2692 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2693
2694 // Update the original variable reference to include the structure reference
2695 string qual_var_name = ib_var_ref + "." + mbr_name;
2696 // If using pull-model interpolation, need to add a call to the correct interpolation method.
2697 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2698 {
2699 if (is_centroid)
2700 qual_var_name += ".interpolate_at_centroid()";
2701 else if (is_sample)
2702 qual_var_name += join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
2703 else
2704 qual_var_name += ".interpolate_at_center()";
2705 }
2706
2707 bool flatten_stage_out = false;
2708
2709 if (is_builtin && !meta.strip_array)
2710 {
2711 // For the builtin gl_PerVertex, we cannot treat it as a block anyway,
2712 // so redirect to qualified name.
2713 set_member_qualified_name(var_type.self, mbr_idx, qual_var_name);
2714 }
2715 else if (!meta.strip_array && meta.allow_local_declaration)
2716 {
2717 // Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
2718 switch (storage)
2719 {
2720 case StorageClassInput:
2721 entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
2722 statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), " = ", qual_var_name, ";");
2723 });
2724 break;
2725
2726 case StorageClassOutput:
2727 flatten_stage_out = true;
2728 entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
2729 statement(qual_var_name, " = ", to_name(var.self), ".", to_member_name(var_type, mbr_idx), ";");
2730 });
2731 break;
2732
2733 default:
2734 break;
2735 }
2736 }
2737
2738 // Copy the variable location from the original variable to the member
2739 if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
2740 {
2741 uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation);
2742 if (storage == StorageClassInput)
2743 {
2744 mbr_type_id = ensure_correct_input_type(mbr_type_id, locn, 0, meta.strip_array);
2745 var_type.member_types[mbr_idx] = mbr_type_id;
2746 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2747 ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective);
2748 else
2749 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2750 }
2751 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2752 mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2753 }
2754 else if (has_decoration(var.self, DecorationLocation))
2755 {
2756 // The block itself might have a location and in this case, all members of the block
2757 // receive incrementing locations.
2758 uint32_t locn = get_accumulated_member_location(var, mbr_idx, meta.strip_array);
2759 if (storage == StorageClassInput)
2760 {
2761 mbr_type_id = ensure_correct_input_type(mbr_type_id, locn, 0, meta.strip_array);
2762 var_type.member_types[mbr_idx] = mbr_type_id;
2763 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2764 ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective);
2765 else
2766 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2767 }
2768 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2769 mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2770 }
2771 else if (is_builtin && is_tessellation_shader() && inputs_by_builtin.count(builtin))
2772 {
2773 uint32_t locn = 0;
2774 auto builtin_itr = inputs_by_builtin.find(builtin);
2775 if (builtin_itr != end(inputs_by_builtin))
2776 locn = builtin_itr->second.location;
2777 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2778 mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2779 }
2780
2781 // Copy the component location, if present.
2782 if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
2783 {
2784 uint32_t comp = get_member_decoration(var_type.self, mbr_idx, DecorationComponent);
2785 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp);
2786 }
2787
2788 // Mark the member as builtin if needed
2789 if (is_builtin)
2790 {
2791 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2792 if (builtin == BuiltInPosition && storage == StorageClassOutput)
2793 qual_pos_var_name = qual_var_name;
2794 }
2795
2796 const SPIRConstant *c = nullptr;
2797 if (!flatten_stage_out && var.storage == StorageClassOutput &&
2798 var.initializer != ID(0) && (c = maybe_get<SPIRConstant>(var.initializer)))
2799 {
2800 if (meta.strip_array)
2801 {
2802 entry_func.fixup_hooks_in.push_back([=, &var]() {
2803 auto &type = this->get<SPIRType>(var.basetype);
2804 uint32_t index = get_extended_member_decoration(var.self, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex);
2805
2806 auto invocation = to_tesc_invocation_id();
2807 auto constant_chain = join(to_expression(var.initializer), "[", invocation, "]");
2808 statement(to_expression(stage_out_ptr_var_id), "[",
2809 invocation, "].",
2810 to_member_name(ib_type, index), " = ",
2811 constant_chain, ".", to_member_name(type, mbr_idx), ";");
2812 });
2813 }
2814 else
2815 {
2816 entry_func.fixup_hooks_in.push_back([=]() {
2817 statement(qual_var_name, " = ", constant_expression(
2818 this->get<SPIRConstant>(c->subconstants[mbr_idx])), ";");
2819 });
2820 }
2821 }
2822
2823 if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
2824 {
2825 // Copy interpolation decorations if needed
2826 if (is_flat)
2827 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2828 if (is_noperspective)
2829 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2830 if (is_centroid)
2831 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2832 if (is_sample)
2833 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2834 }
2835
2836 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2837 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);
2838 }
2839
2840 // In Metal, the tessellation levels are stored as tightly packed half-precision floating point values.
2841 // But, stage-in attribute offsets and strides must be multiples of four, so we can't pass the levels
2842 // individually. Therefore, we must pass them as vectors. Triangles get a single float4, with the outer
2843 // levels in 'xyz' and the inner level in 'w'. Quads get a float4 containing the outer levels and a
2844 // float2 containing the inner levels.
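// For example (illustrative), with triangles the generated fixup unpacks a single
// packed member, roughly:
//   gl_TessLevelOuter[0] = in.gl_TessLevel.x; ... gl_TessLevelInner[0] = in.gl_TessLevel.w;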
2845 void CompilerMSL::add_tess_level_input_to_interface_block(const std::string &ib_var_ref, SPIRType &ib_type,
2846 SPIRVariable &var)
2847 {
2848 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2849 auto &var_type = get_variable_element_type(var);
2850
2851 BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
2852
2853 // Force the variable to have the proper name.
2854 string var_name = builtin_to_glsl(builtin, StorageClassFunction);
2855 set_name(var.self, var_name);
2856
2857 // We need to declare the variable early and at entry-point scope.
2858 entry_func.add_local_variable(var.self);
2859 vars_needing_early_declaration.push_back(var.self);
2860 bool triangles = get_execution_mode_bitset().get(ExecutionModeTriangles);
2861 string mbr_name;
2862
2863 // Add a reference to the variable type to the interface struct.
2864 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2865
2866 const auto mark_locations = [&](const SPIRType &new_var_type) {
2867 if (get_decoration_bitset(var.self).get(DecorationLocation))
2868 {
2869 uint32_t locn = get_decoration(var.self, DecorationLocation);
2870 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2871 mark_location_as_used_by_shader(locn, new_var_type, StorageClassInput);
2872 }
2873 else if (inputs_by_builtin.count(builtin))
2874 {
2875 uint32_t locn = inputs_by_builtin[builtin].location;
2876 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2877 mark_location_as_used_by_shader(locn, new_var_type, StorageClassInput);
2878 }
2879 };
2880
2881 if (triangles)
2882 {
2883 // Triangles are tricky, because we want only one member in the struct.
2884 mbr_name = "gl_TessLevel";
2885
2886 // If we already added the other one, we can skip this step.
2887 if (!added_builtin_tess_level)
2888 {
2889 uint32_t type_id = build_extended_vector_type(var_type.self, 4);
2890
2891 ib_type.member_types.push_back(type_id);
2892
2893 // Give the member a name
2894 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2895
2896 // We cannot decorate both, but the important part is that
2897 // it's marked as builtin so we can get automatic attribute assignment if needed.
2898 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2899
2900 mark_locations(var_type);
2901 added_builtin_tess_level = true;
2902 }
2903 }
2904 else
2905 {
2906 mbr_name = var_name;
2907
2908 uint32_t type_id = build_extended_vector_type(var_type.self, builtin == BuiltInTessLevelOuter ? 4 : 2);
2909
2910 uint32_t ptr_type_id = ir.increase_bound_by(1);
2911 auto &new_var_type = set<SPIRType>(ptr_type_id, get<SPIRType>(type_id));
2912 new_var_type.pointer = true;
2913 new_var_type.pointer_depth++;
2914 new_var_type.storage = StorageClassInput;
2915 new_var_type.parent_type = type_id;
2916
2917 ib_type.member_types.push_back(type_id);
2918
2919 // Give the member a name
2920 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2921 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2922
2923 mark_locations(new_var_type);
2924 }
2925
2926 if (builtin == BuiltInTessLevelOuter)
2927 {
2928 entry_func.fixup_hooks_in.push_back([=]() {
2929 statement(var_name, "[0] = ", ib_var_ref, ".", mbr_name, ".x;");
2930 statement(var_name, "[1] = ", ib_var_ref, ".", mbr_name, ".y;");
2931 statement(var_name, "[2] = ", ib_var_ref, ".", mbr_name, ".z;");
2932 if (!triangles)
2933 statement(var_name, "[3] = ", ib_var_ref, ".", mbr_name, ".w;");
2934 });
2935 }
2936 else
2937 {
2938 entry_func.fixup_hooks_in.push_back([=]() {
2939 if (triangles)
2940 {
2941 statement(var_name, "[0] = ", ib_var_ref, ".", mbr_name, ".w;");
2942 }
2943 else
2944 {
2945 statement(var_name, "[0] = ", ib_var_ref, ".", mbr_name, ".x;");
2946 statement(var_name, "[1] = ", ib_var_ref, ".", mbr_name, ".y;");
2947 }
2948 });
2949 }
2950 }
2951
2952 bool CompilerMSL::variable_storage_requires_stage_io(spv::StorageClass storage) const
2953 {
2954 if (storage == StorageClassOutput)
2955 return !capture_output_to_buffer;
2956 else if (storage == StorageClassInput)
2957 return !(get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup);
2958 else
2959 return false;
2960 }
2961
2962 string CompilerMSL::to_tesc_invocation_id()
2963 {
2964 if (msl_options.multi_patch_workgroup)
2965 {
2966 // n.b. builtin_invocation_id_id here is the dispatch global invocation ID,
2967 // not the TC invocation ID.
2968 return join(to_expression(builtin_invocation_id_id), ".x % ", get_entry_point().output_vertices);
2969 }
2970 else
2971 return builtin_to_glsl(BuiltInInvocationId, StorageClassInput);
2972 }
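// E.g. with 4 output control points per patch, the multi-patch path above yields the expression
// "gl_GlobalInvocationID.x % 4", while the single-patch path yields plain "gl_InvocationID"
// (a sketch; the exact spelling depends on how builtin_invocation_id_id is expressed).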
2973
2974 void CompilerMSL::emit_local_masked_variable(const SPIRVariable &masked_var, bool strip_array)
2975 {
2976 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2977 bool threadgroup_storage = variable_decl_is_remapped_storage(masked_var, StorageClassWorkgroup);
2978
2979 if (threadgroup_storage && msl_options.multi_patch_workgroup)
2980 {
2981 // We need one threadgroup block per patch, so fake this.
2982 entry_func.fixup_hooks_in.push_back([this, &masked_var]() {
2983 auto &type = get_variable_data_type(masked_var);
2984 add_local_variable_name(masked_var.self);
2985
2986 bool old_is_builtin = is_using_builtin_array;
2987 is_using_builtin_array = true;
2988
2989 const uint32_t max_control_points_per_patch = 32u;
2990 uint32_t max_num_instances =
2991 (max_control_points_per_patch + get_entry_point().output_vertices - 1u) /
2992 get_entry_point().output_vertices;
2993 statement("threadgroup ", type_to_glsl(type), " ",
2994 "spvStorage", to_name(masked_var.self), "[", max_num_instances, "]",
2995 type_to_array_glsl(type), ";");
2996
2997 // Assign a threadgroup slice to each PrimitiveID.
2998 // We assume here that workgroup size is rounded to 32,
2999 // since that's the maximum number of control points per patch.
3000 // We cannot size the array based on fixed dispatch parameters,
3001 // since Metal does not allow that. :(
3002 // FIXME: We will likely need an option to support passing down target workgroup size,
3003 // so we can emit appropriate size here.
3004 statement("threadgroup ", type_to_glsl(type), " ",
3005 "(&", to_name(masked_var.self), ")",
3006 type_to_array_glsl(type), " = spvStorage", to_name(masked_var.self), "[",
3007 "(", to_expression(builtin_invocation_id_id), ".x / ",
3008 get_entry_point().output_vertices, ") % ",
3009 max_num_instances, "];");
3010
3011 is_using_builtin_array = old_is_builtin;
3012 });
3013 }
3014 else
3015 {
3016 entry_func.add_local_variable(masked_var.self);
3017 }
3018
3019 if (!threadgroup_storage)
3020 {
3021 vars_needing_early_declaration.push_back(masked_var.self);
3022 }
3023 else if (masked_var.initializer)
3024 {
3025 // Cannot directly initialize threadgroup variables. Need fixup hooks.
3026 ID initializer = masked_var.initializer;
3027 if (strip_array)
3028 {
3029 entry_func.fixup_hooks_in.push_back([this, &masked_var, initializer]() {
3030 auto invocation = to_tesc_invocation_id();
3031 statement(to_expression(masked_var.self), "[",
3032 invocation, "] = ",
3033 to_expression(initializer), "[",
3034 invocation, "];");
3035 });
3036 }
3037 else
3038 {
3039 entry_func.fixup_hooks_in.push_back([this, &masked_var, initializer]() {
3040 statement(to_expression(masked_var.self), " = ", to_expression(initializer), ";");
3041 });
3042 }
3043 }
3044 }
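// Sketch of the multi-patch threadgroup fake-up above, assuming a masked variable "foo" of type
// float[2] and 4 output control points (so (32 + 4 - 1) / 4 = 8 instances; names illustrative):
//
//   threadgroup float spvStoragefoo[8][2];
//   threadgroup float (&foo)[2] = spvStoragefoo[(gl_GlobalInvocationID.x / 4) % 8];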
3045
3046 void CompilerMSL::add_variable_to_interface_block(StorageClass storage, const string &ib_var_ref, SPIRType &ib_type,
3047 SPIRVariable &var, InterfaceBlockMeta &meta)
3048 {
3049 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
3050 // Tessellation control I/O variables and tessellation evaluation per-point inputs are
3051 // usually declared as arrays. In these cases, we want to add the element type to the
3052 // interface block, since in Metal it's the interface block itself which is arrayed.
3053 auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
3054 bool is_builtin = is_builtin_variable(var);
3055 auto builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
3056 bool is_block = has_decoration(var_type.self, DecorationBlock);
3057
3058 // If stage variables are masked out, emit them as plain variables instead.
3059 // For builtins, we query them one by one later.
3060 // IO blocks are not masked here, we need to mask them per-member instead.
3061 if (storage == StorageClassOutput && is_stage_output_variable_masked(var))
3062 {
3063 // If we ignore an output, we must still emit it, since it might be used by the app.
3064 // Instead, just emit it as early declaration.
3065 emit_local_masked_variable(var, meta.strip_array);
3066 return;
3067 }
3068
3069 if (var_type.basetype == SPIRType::Struct)
3070 {
3071 bool block_requires_flattening = variable_storage_requires_stage_io(storage) || is_block;
3072 bool needs_local_declaration = !is_builtin && block_requires_flattening && meta.allow_local_declaration;
3073
3074 if (needs_local_declaration)
3075 {
3076 // For I/O blocks or structs, we will need to pass the block itself around
3077 // to functions if they are used globally in leaf functions.
3078 // Rather than passing down member by member,
3079 // we unflatten I/O blocks while running the shader,
3080 // and pass the actual struct type down to leaf functions.
3081 // We then unflatten inputs, and flatten outputs in the "fixup" stages.
3082 emit_local_masked_variable(var, meta.strip_array);
3083 }
3084
3085 if (!block_requires_flattening)
3086 {
3087 // In Metal tessellation shaders, the interface block itself is arrayed. This makes things
3088 // very complicated, since stage-in structures in MSL don't support nested structures.
3089 // Luckily, for stage-out when capturing output, we can avoid this and just add
3090 // composite members directly, because the stage-out structure is stored to a buffer,
3091 // not returned.
3092 add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
3093 }
3094 else
3095 {
3096 bool masked_block = false;
3097
3098 // Flatten the struct members into the interface struct
3099 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++)
3100 {
3101 builtin = BuiltInMax;
3102 is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
3103 auto &mbr_type = get<SPIRType>(var_type.member_types[mbr_idx]);
3104
3105 if (storage == StorageClassOutput && is_stage_output_block_member_masked(var, mbr_idx, meta.strip_array))
3106 {
3107 if (is_block)
3108 masked_block = true;
3109
3110 // Non-builtin block output variables are just ignored, since they will still access
3111 // the block variable as-is. They're just not flattened.
3112 if (is_builtin && !meta.strip_array)
3113 {
3114 // Emit a fake variable instead.
3115 uint32_t ids = ir.increase_bound_by(2);
3116 uint32_t ptr_type_id = ids + 0;
3117 uint32_t var_id = ids + 1;
3118
3119 auto ptr_type = mbr_type;
3120 ptr_type.pointer = true;
3121 ptr_type.pointer_depth++;
3122 ptr_type.parent_type = var_type.member_types[mbr_idx];
3123 ptr_type.storage = StorageClassOutput;
3124
3125 uint32_t initializer = 0;
3126 if (var.initializer)
3127 if (auto *c = maybe_get<SPIRConstant>(var.initializer))
3128 initializer = c->subconstants[mbr_idx];
3129
3130 set<SPIRType>(ptr_type_id, ptr_type);
3131 set<SPIRVariable>(var_id, ptr_type_id, StorageClassOutput, initializer);
3132 entry_func.add_local_variable(var_id);
3133 vars_needing_early_declaration.push_back(var_id);
3134 set_name(var_id, builtin_to_glsl(builtin, StorageClassOutput));
3135 set_decoration(var_id, DecorationBuiltIn, builtin);
3136 }
3137 }
3138 else if (!is_builtin || has_active_builtin(builtin, storage))
3139 {
3140 bool is_composite_type = is_matrix(mbr_type) || is_array(mbr_type);
3141 bool attribute_load_store =
3142 storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;
3143 bool storage_is_stage_io = variable_storage_requires_stage_io(storage);
3144
3145 // Clip/CullDistance always need to be declared as user attributes.
3146 if (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance)
3147 is_builtin = false;
3148
3149 if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
3150 {
3151 add_composite_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx,
3152 meta);
3153 }
3154 else
3155 {
3156 add_plain_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx, meta);
3157 }
3158 }
3159 }
3160
3161 // If we're redirecting a block, we might still need to access the original block
3162 // variable if we're masking some members.
3163 if (masked_block && !needs_local_declaration &&
3164 (!is_builtin_variable(var) || get_execution_model() == ExecutionModelTessellationControl))
3165 {
3166 if (is_builtin_variable(var))
3167 {
3168 // Ensure correct names for the block members if we're actually going to
3169 // declare gl_PerVertex.
3170 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++)
3171 {
3172 set_member_name(var_type.self, mbr_idx, builtin_to_glsl(
3173 BuiltIn(get_member_decoration(var_type.self, mbr_idx, DecorationBuiltIn)),
3174 StorageClassOutput));
3175 }
3176
3177 set_name(var_type.self, "gl_PerVertex");
3178 set_name(var.self, "gl_out_masked");
3179 stage_out_masked_builtin_type_id = var_type.self;
3180 }
3181 emit_local_masked_variable(var, meta.strip_array);
3182 }
3183 }
3184 }
3185 else if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput &&
3186 !meta.strip_array && is_builtin && (builtin == BuiltInTessLevelOuter || builtin == BuiltInTessLevelInner))
3187 {
3188 add_tess_level_input_to_interface_block(ib_var_ref, ib_type, var);
3189 }
3190 else if (var_type.basetype == SPIRType::Boolean || var_type.basetype == SPIRType::Char ||
3191 type_is_integral(var_type) || type_is_floating_point(var_type))
3192 {
3193 if (!is_builtin || has_active_builtin(builtin, storage))
3194 {
3195 bool is_composite_type = is_matrix(var_type) || is_array(var_type);
3196 bool storage_is_stage_io = variable_storage_requires_stage_io(storage);
3197 bool attribute_load_store = storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;
3198
3199 // Clip/CullDistance always needs to be declared as user attributes.
3200 if (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance)
3201 is_builtin = false;
3202
3203 // MSL does not allow matrices or arrays in input or output variables, so we need to handle them specially.
3204 if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
3205 {
3206 add_composite_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
3207 }
3208 else
3209 {
3210 add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
3211 }
3212 }
3213 }
3214 }
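// To illustrate the flattening path above (a rough sketch; real member names and attributes
// depend on the block in question), a vertex-stage output block such as
//
//   out gl_PerVertex { vec4 gl_Position; float gl_PointSize; };
//
// ends up contributing individual stage-out members along the lines of:
//
//   float4 gl_Position [[position]];
//   float gl_PointSize [[point_size]];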
3215
3216 // Fix up the mapping of variables to interface member indices, which is used to compile access chains
3217 // for per-vertex variables in a tessellation control shader.
3218 void CompilerMSL::fix_up_interface_member_indices(StorageClass storage, uint32_t ib_type_id)
3219 {
3220 // Only needed for tessellation shaders and pull-model interpolants.
3221 // Need to redirect interface indices back to variables themselves.
3222 // For structs, each member of the struct needs a separate instance.
3223 if (get_execution_model() != ExecutionModelTessellationControl &&
3224 !(get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput) &&
3225 !(get_execution_model() == ExecutionModelFragment && storage == StorageClassInput &&
3226 !pull_model_inputs.empty()))
3227 return;
3228
3229 auto mbr_cnt = uint32_t(ir.meta[ib_type_id].members.size());
3230 for (uint32_t i = 0; i < mbr_cnt; i++)
3231 {
3232 uint32_t var_id = get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceOrigID);
3233 if (!var_id)
3234 continue;
3235 auto &var = get<SPIRVariable>(var_id);
3236
3237 auto &type = get_variable_element_type(var);
3238
3239 bool flatten_composites = variable_storage_requires_stage_io(var.storage);
3240 bool is_block = has_decoration(type.self, DecorationBlock);
3241
3242 uint32_t mbr_idx = uint32_t(-1);
3243 if (type.basetype == SPIRType::Struct && (flatten_composites || is_block))
3244 mbr_idx = get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceMemberIndex);
3245
3246 if (mbr_idx != uint32_t(-1))
3247 {
3248 // Only set the lowest InterfaceMemberIndex for each variable member.
3249 // IB struct members will be emitted in-order w.r.t. interface member index.
3250 if (!has_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex))
3251 set_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, i);
3252 }
3253 else
3254 {
3255 // Only set the lowest InterfaceMemberIndex for each variable.
3256 // IB struct members will be emitted in-order w.r.t. interface member index.
3257 if (!has_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex))
3258 set_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex, i);
3259 }
3260 }
3261 }
3262
3263 // Add an interface structure for the type of storage, which is either StorageClassInput or StorageClassOutput.
3264 // Returns the ID of the newly added variable, or zero if no variable was added.
3265 uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch)
3266 {
3267 // Accumulate the variables that should appear in the interface struct.
3268 SmallVector<SPIRVariable *> vars;
3269 bool incl_builtins = storage == StorageClassOutput || is_tessellation_shader();
3270 bool has_seen_barycentric = false;
3271
3272 InterfaceBlockMeta meta;
3273
3274 // Varying interfaces between stages which use the "user()" attribute can be dealt with
3275 // without explicit packing and unpacking of components. For any variables which link against the runtime
3276 // in some way (vertex attributes, fragment output, etc), we'll need to deal with it somehow.
3277 bool pack_components =
3278 (storage == StorageClassInput && get_execution_model() == ExecutionModelVertex) ||
3279 (storage == StorageClassOutput && get_execution_model() == ExecutionModelFragment) ||
3280 (storage == StorageClassOutput && get_execution_model() == ExecutionModelVertex && capture_output_to_buffer);
3281
3282 ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
3283 if (var.storage != storage)
3284 return;
3285
3286 auto &type = this->get<SPIRType>(var.basetype);
3287
3288 bool is_builtin = is_builtin_variable(var);
3289 bool is_block = has_decoration(type.self, DecorationBlock);
3290
3291 auto bi_type = BuiltInMax;
3292 bool builtin_is_gl_in_out = false;
3293 if (is_builtin && !is_block)
3294 {
3295 bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
3296 builtin_is_gl_in_out = bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
3297 bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance;
3298 }
3299
3300 if (is_builtin && is_block)
3301 builtin_is_gl_in_out = true;
3302
3303 uint32_t location = get_decoration(var_id, DecorationLocation);
3304
3305 bool builtin_is_stage_in_out = builtin_is_gl_in_out ||
3306 bi_type == BuiltInLayer || bi_type == BuiltInViewportIndex ||
3307 bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV ||
3308 bi_type == BuiltInFragDepth ||
3309 bi_type == BuiltInFragStencilRefEXT || bi_type == BuiltInSampleMask;
3310
3311 // These builtins are part of the stage in/out structs.
3312 bool is_interface_block_builtin =
3313 builtin_is_stage_in_out ||
3314 (get_execution_model() == ExecutionModelTessellationEvaluation &&
3315 (bi_type == BuiltInTessLevelOuter || bi_type == BuiltInTessLevelInner));
3316
3317 bool is_active = interface_variable_exists_in_entry_point(var.self);
3318 if (is_builtin && is_active)
3319 {
3320 // Only emit the builtin if it's active in this entry point. Interface variable list might lie.
3321 if (is_block)
3322 {
3323 // If any builtin is active, the block is active.
3324 uint32_t mbr_cnt = uint32_t(type.member_types.size());
3325 for (uint32_t i = 0; !is_active && i < mbr_cnt; i++)
3326 is_active = has_active_builtin(BuiltIn(get_member_decoration(type.self, i, DecorationBuiltIn)), storage);
3327 }
3328 else
3329 {
3330 is_active = has_active_builtin(bi_type, storage);
3331 }
3332 }
3333
3334 bool filter_patch_decoration = (has_decoration(var_id, DecorationPatch) || is_patch_block(type)) == patch;
3335
3336 bool hidden = is_hidden_variable(var, incl_builtins);
3337
3338 // ClipDistance is never hidden, we need to emulate it when used as an input.
3339 if (bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance)
3340 hidden = false;
3341
3342 // It's not enough to simply avoid marking fragment outputs if the pipeline won't
3343 // accept them. We can't put them in the struct at all, otherwise the compiler
3344 // complains that the outputs weren't explicitly marked.
3345 if (get_execution_model() == ExecutionModelFragment && storage == StorageClassOutput && !patch &&
3346 ((is_builtin && ((bi_type == BuiltInFragDepth && !msl_options.enable_frag_depth_builtin) ||
3347 (bi_type == BuiltInFragStencilRefEXT && !msl_options.enable_frag_stencil_ref_builtin))) ||
3348 (!is_builtin && !(msl_options.enable_frag_output_mask & (1 << location)))))
3349 {
3350 hidden = true;
3351 disabled_frag_outputs.push_back(var_id);
3352 // If a builtin, force it to have the proper name.
3353 if (is_builtin)
3354 set_name(var_id, builtin_to_glsl(bi_type, StorageClassFunction));
3355 }
3356
3357 // Barycentric inputs must be emitted in stage-in, because they can have interpolation arguments.
3358 if (is_active && (bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV))
3359 {
3360 if (has_seen_barycentric)
3361 SPIRV_CROSS_THROW("Cannot declare both BaryCoordNV and BaryCoordNoPerspNV in same shader in MSL.");
3362 has_seen_barycentric = true;
3363 hidden = false;
3364 }
3365
3366 if (is_active && !hidden && type.pointer && filter_patch_decoration &&
3367 (!is_builtin || is_interface_block_builtin))
3368 {
3369 vars.push_back(&var);
3370
3371 if (!is_builtin)
3372 {
3373 // Need to deal specially with DecorationComponent.
3374 // Multiple variables can alias the same Location, so we try to make sure each location is declared only once.
3375 // We will swizzle data in and out to make this work.
3376 // This is only relevant for vertex inputs and fragment outputs.
3377 // Technically tessellation as well, but it is too complicated to support.
3378 uint32_t component = get_decoration(var_id, DecorationComponent);
3379 if (component != 0)
3380 {
3381 if (is_tessellation_shader())
3382 SPIRV_CROSS_THROW("Component decoration is not supported in tessellation shaders.");
3383 else if (pack_components)
3384 {
3385 uint32_t array_size = 1;
3386 if (!type.array.empty())
3387 array_size = to_array_size_literal(type);
3388
3389 for (uint32_t location_offset = 0; location_offset < array_size; location_offset++)
3390 {
3391 auto &location_meta = meta.location_meta[location + location_offset];
3392 location_meta.num_components = std::max(location_meta.num_components, component + type.vecsize);
3393
3394 // For variables sharing location, decorations and base type must match.
3395 location_meta.base_type_id = type.self;
3396 location_meta.flat = has_decoration(var.self, DecorationFlat);
3397 location_meta.noperspective = has_decoration(var.self, DecorationNoPerspective);
3398 location_meta.centroid = has_decoration(var.self, DecorationCentroid);
3399 location_meta.sample = has_decoration(var.self, DecorationSample);
3400 }
3401 }
3402 }
3403 }
3404 }
3405 });
3406
3407 // If no variables qualify, leave.
3408 // For patch input in a tessellation evaluation shader, the per-vertex stage inputs
3409 // are included in a special patch control point array.
3410 if (vars.empty() && !(storage == StorageClassInput && patch && stage_in_var_id))
3411 return 0;
3412
3413 // Add a new typed variable for this interface structure.
3414 // The initializer expression is allocated here, but populated when the function
3415 // declaration is emitted, because it is cleared after each compilation pass.
3416 uint32_t next_id = ir.increase_bound_by(3);
3417 uint32_t ib_type_id = next_id++;
3418 auto &ib_type = set<SPIRType>(ib_type_id);
3419 ib_type.basetype = SPIRType::Struct;
3420 ib_type.storage = storage;
3421 set_decoration(ib_type_id, DecorationBlock);
3422
3423 uint32_t ib_var_id = next_id++;
3424 auto &var = set<SPIRVariable>(ib_var_id, ib_type_id, storage, 0);
3425 var.initializer = next_id++;
3426
3427 string ib_var_ref;
3428 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
3429 switch (storage)
3430 {
3431 case StorageClassInput:
3432 ib_var_ref = patch ? patch_stage_in_var_name : stage_in_var_name;
3433 if (get_execution_model() == ExecutionModelTessellationControl)
3434 {
3435 // Add a hook to populate the shared workgroup memory containing the gl_in array.
3436 entry_func.fixup_hooks_in.push_back([=]() {
3437 // Can't use PatchVertices, PrimitiveId, or InvocationId here; the hooks for those may not have run yet.
3438 if (msl_options.multi_patch_workgroup)
3439 {
3440 // n.b. builtin_invocation_id_id here is the dispatch global invocation ID,
3441 // not the TC invocation ID.
3442 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_in = &",
3443 input_buffer_var_name, "[min(", to_expression(builtin_invocation_id_id), ".x / ",
3444 get_entry_point().output_vertices,
3445 ", spvIndirectParams[1] - 1) * spvIndirectParams[0]];");
3446 }
3447 else
3448 {
3449 // It's safe to use InvocationId here because it's directly mapped to a
3450 // Metal builtin, and therefore doesn't need a hook.
3451 statement("if (", to_expression(builtin_invocation_id_id), " < spvIndirectParams[0])");
3452 statement(" ", input_wg_var_name, "[", to_expression(builtin_invocation_id_id),
3453 "] = ", ib_var_ref, ";");
3454 statement("threadgroup_barrier(mem_flags::mem_threadgroup);");
3455 statement("if (", to_expression(builtin_invocation_id_id),
3456 " >= ", get_entry_point().output_vertices, ")");
3457 statement(" return;");
3458 }
3459 });
3460 }
3461 break;
3462
3463 case StorageClassOutput:
3464 {
3465 ib_var_ref = patch ? patch_stage_out_var_name : stage_out_var_name;
3466
3467 // Add the output interface struct as a local variable to the entry function.
3468 // If the entry point should return the output struct, set the entry function
3469 // to return the output interface struct, otherwise to return nothing.
3470 // Indicate the output var requires early initialization.
3471 bool ep_should_return_output = !get_is_rasterization_disabled();
3472 uint32_t rtn_id = ep_should_return_output ? ib_var_id : 0;
3473 if (!capture_output_to_buffer)
3474 {
3475 entry_func.add_local_variable(ib_var_id);
3476 for (auto &blk_id : entry_func.blocks)
3477 {
3478 auto &blk = get<SPIRBlock>(blk_id);
3479 if (blk.terminator == SPIRBlock::Return)
3480 blk.return_value = rtn_id;
3481 }
3482 vars_needing_early_declaration.push_back(ib_var_id);
3483 }
3484 else
3485 {
3486 switch (get_execution_model())
3487 {
3488 case ExecutionModelVertex:
3489 case ExecutionModelTessellationEvaluation:
3490 // Instead of declaring a struct variable to hold the output and then
3491 // copying that to the output buffer, we'll declare the output variable
3492 // as a reference to the final output element in the buffer. Then we can
3493 // avoid the extra copy.
3494 entry_func.fixup_hooks_in.push_back([=]() {
3495 if (stage_out_var_id)
3496 {
3497 // The first member of the indirect buffer is always the number of vertices
3498 // to draw.
3499 // We zero-base the InstanceID & VertexID variables for HLSL emulation elsewhere, so don't do it twice
3500 if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
3501 {
3502 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3503 " = ", output_buffer_var_name, "[", to_expression(builtin_invocation_id_id),
3504 ".y * ", to_expression(builtin_stage_input_size_id), ".x + ",
3505 to_expression(builtin_invocation_id_id), ".x];");
3506 }
3507 else if (msl_options.enable_base_index_zero)
3508 {
3509 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3510 " = ", output_buffer_var_name, "[", to_expression(builtin_instance_idx_id),
3511 " * spvIndirectParams[0] + ", to_expression(builtin_vertex_idx_id), "];");
3512 }
3513 else
3514 {
3515 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3516 " = ", output_buffer_var_name, "[(", to_expression(builtin_instance_idx_id),
3517 " - ", to_expression(builtin_base_instance_id), ") * spvIndirectParams[0] + ",
3518 to_expression(builtin_vertex_idx_id), " - ",
3519 to_expression(builtin_base_vertex_id), "];");
3520 }
3521 }
3522 });
3523 break;
3524 case ExecutionModelTessellationControl:
3525 if (msl_options.multi_patch_workgroup)
3526 {
3527 // We cannot use PrimitiveId here, because the hook may not have run yet.
3528 if (patch)
3529 {
3530 entry_func.fixup_hooks_in.push_back([=]() {
3531 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3532 " = ", patch_output_buffer_var_name, "[", to_expression(builtin_invocation_id_id),
3533 ".x / ", get_entry_point().output_vertices, "];");
3534 });
3535 }
3536 else
3537 {
3538 entry_func.fixup_hooks_in.push_back([=]() {
3539 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &",
3540 output_buffer_var_name, "[", to_expression(builtin_invocation_id_id), ".x - ",
3541 to_expression(builtin_invocation_id_id), ".x % ",
3542 get_entry_point().output_vertices, "];");
3543 });
3544 }
3545 }
3546 else
3547 {
3548 if (patch)
3549 {
3550 entry_func.fixup_hooks_in.push_back([=]() {
3551 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
3552 " = ", patch_output_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
3553 "];");
3554 });
3555 }
3556 else
3557 {
3558 entry_func.fixup_hooks_in.push_back([=]() {
3559 statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &",
3560 output_buffer_var_name, "[", to_expression(builtin_primitive_id_id), " * ",
3561 get_entry_point().output_vertices, "];");
3562 });
3563 }
3564 }
3565 break;
3566 default:
3567 break;
3568 }
3569 }
3570 break;
3571 }
3572
3573 default:
3574 break;
3575 }
3576
3577 set_name(ib_type_id, to_name(ir.default_entry_point) + "_" + ib_var_ref);
3578 set_name(ib_var_id, ib_var_ref);
3579
3580 for (auto *p_var : vars)
3581 {
3582 bool strip_array =
3583 (get_execution_model() == ExecutionModelTessellationControl ||
3584 (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput)) &&
3585 !patch;
3586
3587 // Fixing up flattened stores in TESC is impossible since the memory is group shared either via
3588 // device (not masked) or threadgroup (masked) storage classes and it's race condition city.
3589 meta.strip_array = strip_array;
3590 meta.allow_local_declaration = !strip_array && !(get_execution_model() == ExecutionModelTessellationControl &&
3591 storage == StorageClassOutput);
3592 add_variable_to_interface_block(storage, ib_var_ref, ib_type, *p_var, meta);
3593 }
3594
3595 if (get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup &&
3596 storage == StorageClassInput)
3597 {
3598 // For tessellation control inputs, add all outputs from the vertex shader to ensure
3599 // the struct containing them is the correct size and layout.
3600 for (auto &input : inputs_by_location)
3601 {
3602 if (location_inputs_in_use.count(input.first) != 0)
3603 continue;
3604
3605 // Create a fake variable to put at the location.
3606 uint32_t offset = ir.increase_bound_by(4);
3607 uint32_t type_id = offset;
3608 uint32_t array_type_id = offset + 1;
3609 uint32_t ptr_type_id = offset + 2;
3610 uint32_t var_id = offset + 3;
3611
3612 SPIRType type;
3613 switch (input.second.format)
3614 {
3615 case MSL_SHADER_INPUT_FORMAT_UINT16:
3616 case MSL_SHADER_INPUT_FORMAT_ANY16:
3617 type.basetype = SPIRType::UShort;
3618 type.width = 16;
3619 break;
3620 case MSL_SHADER_INPUT_FORMAT_ANY32:
3621 default:
3622 type.basetype = SPIRType::UInt;
3623 type.width = 32;
3624 break;
3625 }
3626 type.vecsize = input.second.vecsize;
3627 set<SPIRType>(type_id, type);
3628
3629 type.array.push_back(0);
3630 type.array_size_literal.push_back(true);
3631 type.parent_type = type_id;
3632 set<SPIRType>(array_type_id, type);
3633
3634 type.pointer = true;
3635 type.pointer_depth++;
3636 type.parent_type = array_type_id;
3637 type.storage = storage;
3638 auto &ptr_type = set<SPIRType>(ptr_type_id, type);
3639 ptr_type.self = array_type_id;
3640
3641 auto &fake_var = set<SPIRVariable>(var_id, ptr_type_id, storage);
3642 set_decoration(var_id, DecorationLocation, input.first);
3643 meta.strip_array = true;
3644 meta.allow_local_declaration = false;
3645 add_variable_to_interface_block(storage, ib_var_ref, ib_type, fake_var, meta);
3646 }
3647 }
3648
3649 // When multiple variables need to access the same location,
3650 // unroll locations one by one and we will flatten output or input as necessary.
3651 for (auto &loc : meta.location_meta)
3652 {
3653 uint32_t location = loc.first;
3654 auto &location_meta = loc.second;
3655
3656 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
3657 uint32_t type_id = build_extended_vector_type(location_meta.base_type_id, location_meta.num_components);
3658 ib_type.member_types.push_back(type_id);
3659
3660 set_member_name(ib_type.self, ib_mbr_idx, join("m_location_", location));
3661 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, location);
3662 mark_location_as_used_by_shader(location, get<SPIRType>(type_id), storage);
3663
3664 if (location_meta.flat)
3665 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
3666 if (location_meta.noperspective)
3667 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
3668 if (location_meta.centroid)
3669 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
3670 if (location_meta.sample)
3671 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
3672 }
3673
3674 // Sort the members of the structure by their locations.
3675 MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::LocationThenBuiltInType);
3676 member_sorter.sort();
3677
3678 // The member indices were saved to the original variables, but after the members
3679 // were sorted, those indices are now likely incorrect. Fix those up now.
3680 fix_up_interface_member_indices(storage, ib_type_id);
3681
3682 // For patch inputs, add one more member, holding the array of control point data.
3683 if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput && patch &&
3684 stage_in_var_id)
3685 {
3686 uint32_t pcp_type_id = ir.increase_bound_by(1);
3687 auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
3688 pcp_type.basetype = SPIRType::ControlPointArray;
3689 pcp_type.parent_type = pcp_type.type_alias = get_stage_in_struct_type().self;
3690 pcp_type.storage = storage;
3691 ir.meta[pcp_type_id] = ir.meta[ib_type.self];
3692 uint32_t mbr_idx = uint32_t(ib_type.member_types.size());
3693 ib_type.member_types.push_back(pcp_type_id);
3694 set_member_name(ib_type.self, mbr_idx, "gl_in");
3695 }
3696
3697 return ib_var_id;
3698 }
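// For a simple vertex shader, the interface blocks constructed here are eventually emitted as
// something like (a sketch; "main0" is the default entry-point name, locations/attributes vary):
//
//   struct main0_in  { float4 position [[attribute(0)]]; };
//   struct main0_out { float4 gl_Position [[position]]; float2 uv [[user(locn0)]]; };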
3699
3700 uint32_t CompilerMSL::add_interface_block_pointer(uint32_t ib_var_id, StorageClass storage)
3701 {
3702 if (!ib_var_id)
3703 return 0;
3704
3705 uint32_t ib_ptr_var_id;
3706 uint32_t next_id = ir.increase_bound_by(3);
3707 auto &ib_type = expression_type(ib_var_id);
3708 if (get_execution_model() == ExecutionModelTessellationControl)
3709 {
3710 // Tessellation control per-vertex I/O is presented as an array, so we must
3711 // do the same with our struct here.
3712 uint32_t ib_ptr_type_id = next_id++;
3713 auto &ib_ptr_type = set<SPIRType>(ib_ptr_type_id, ib_type);
3714 ib_ptr_type.parent_type = ib_ptr_type.type_alias = ib_type.self;
3715 ib_ptr_type.pointer = true;
3716 ib_ptr_type.pointer_depth++;
3717 ib_ptr_type.storage =
3718 storage == StorageClassInput ?
3719 (msl_options.multi_patch_workgroup ? StorageClassStorageBuffer : StorageClassWorkgroup) :
3720 StorageClassStorageBuffer;
3721 ir.meta[ib_ptr_type_id] = ir.meta[ib_type.self];
3722 // To ensure that get_variable_data_type() doesn't strip off the pointer,
3723 // which we need, use another pointer.
3724 uint32_t ib_ptr_ptr_type_id = next_id++;
3725 auto &ib_ptr_ptr_type = set<SPIRType>(ib_ptr_ptr_type_id, ib_ptr_type);
3726 ib_ptr_ptr_type.parent_type = ib_ptr_type_id;
3727 ib_ptr_ptr_type.type_alias = ib_type.self;
3728 ib_ptr_ptr_type.storage = StorageClassFunction;
3729 ir.meta[ib_ptr_ptr_type_id] = ir.meta[ib_type.self];
3730
3731 ib_ptr_var_id = next_id;
3732 set<SPIRVariable>(ib_ptr_var_id, ib_ptr_ptr_type_id, StorageClassFunction, 0);
3733 set_name(ib_ptr_var_id, storage == StorageClassInput ? "gl_in" : "gl_out");
3734 }
3735 else
3736 {
3737 // Tessellation evaluation per-vertex inputs are also presented as arrays.
3738 // But, in Metal, this array uses a very special type, 'patch_control_point<T>',
3739 // which is a container that can be used to access the control point data.
3740 // To represent this, a special 'ControlPointArray' type has been added to the
3741 // SPIRV-Cross type system. It should only be generated by and seen in the MSL
3742 // backend (i.e. this one).
3743 uint32_t pcp_type_id = next_id++;
3744 auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
3745 pcp_type.basetype = SPIRType::ControlPointArray;
3746 pcp_type.parent_type = pcp_type.type_alias = ib_type.self;
3747 pcp_type.storage = storage;
3748 ir.meta[pcp_type_id] = ir.meta[ib_type.self];
3749
3750 ib_ptr_var_id = next_id;
3751 set<SPIRVariable>(ib_ptr_var_id, pcp_type_id, storage, 0);
3752 set_name(ib_ptr_var_id, "gl_in");
3753 ir.meta[ib_ptr_var_id].decoration.qualified_alias = join(patch_stage_in_var_name, ".gl_in");
3754 }
3755 return ib_ptr_var_id;
3756 }
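// Net effect (sketch): in tessellation control shaders, gl_in/gl_out become pointers into
// threadgroup or device memory, while in tessellation evaluation shaders gl_in resolves to
// something like "patchIn.gl_in", where the patch input struct holds a member of the form:
//
//   patch_control_point<main0_in> gl_in;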
3757
3758 // Ensure that the type is compatible with the builtin.
3759 // If it is, simply return the given type ID.
3760 // Otherwise, create a new type, and return its ID.
3761 uint32_t CompilerMSL::ensure_correct_builtin_type(uint32_t type_id, BuiltIn builtin)
3762 {
3763 auto &type = get<SPIRType>(type_id);
3764
3765 if ((builtin == BuiltInSampleMask && is_array(type)) ||
3766 ((builtin == BuiltInLayer || builtin == BuiltInViewportIndex || builtin == BuiltInFragStencilRefEXT) &&
3767 type.basetype != SPIRType::UInt))
3768 {
3769 uint32_t next_id = ir.increase_bound_by(type.pointer ? 2 : 1);
3770 uint32_t base_type_id = next_id++;
3771 auto &base_type = set<SPIRType>(base_type_id);
3772 base_type.basetype = SPIRType::UInt;
3773 base_type.width = 32;
3774
3775 if (!type.pointer)
3776 return base_type_id;
3777
3778 uint32_t ptr_type_id = next_id++;
3779 auto &ptr_type = set<SPIRType>(ptr_type_id);
3780 ptr_type = base_type;
3781 ptr_type.pointer = true;
3782 ptr_type.pointer_depth++;
3783 ptr_type.storage = type.storage;
3784 ptr_type.parent_type = base_type_id;
3785 return ptr_type_id;
3786 }
3787
3788 return type_id;
3789 }
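// For example, if a module declares gl_Layer (BuiltInLayer) with a signed int type, the code
// above substitutes a fresh 32-bit uint (and matching pointer) type, since MSL's
// [[render_target_array_index]] expects an unsigned value. (Illustrative; see the condition above.)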
3790
3791 // Ensure that the type is compatible with the shader input.
3792 // If it is, simply return the given type ID.
3793 // Otherwise, create a new type, and return its ID.
3794 uint32_t CompilerMSL::ensure_correct_input_type(uint32_t type_id, uint32_t location, uint32_t num_components, bool strip_array)
3795 {
3796 auto &type = get<SPIRType>(type_id);
3797
3798 uint32_t max_array_dimensions = strip_array ? 1 : 0;
3799
3800 // Struct and array types must match exactly.
3801 if (type.basetype == SPIRType::Struct || type.array.size() > max_array_dimensions)
3802 return type_id;
3803
3804 auto p_va = inputs_by_location.find(location);
3805 if (p_va == end(inputs_by_location))
3806 {
3807 if (num_components > type.vecsize)
3808 return build_extended_vector_type(type_id, num_components);
3809 else
3810 return type_id;
3811 }
3812
3813 if (num_components == 0)
3814 num_components = p_va->second.vecsize;
3815
3816 switch (p_va->second.format)
3817 {
3818 case MSL_SHADER_INPUT_FORMAT_UINT8:
3819 {
3820 switch (type.basetype)
3821 {
3822 case SPIRType::UByte:
3823 case SPIRType::UShort:
3824 case SPIRType::UInt:
3825 if (num_components > type.vecsize)
3826 return build_extended_vector_type(type_id, num_components);
3827 else
3828 return type_id;
3829
3830 case SPIRType::Short:
3831 return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3832 SPIRType::UShort);
3833 case SPIRType::Int:
3834 return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3835 SPIRType::UInt);
3836
3837 default:
3838 SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
3839 }
3840 }
3841
3842 case MSL_SHADER_INPUT_FORMAT_UINT16:
3843 {
3844 switch (type.basetype)
3845 {
3846 case SPIRType::UShort:
3847 case SPIRType::UInt:
3848 if (num_components > type.vecsize)
3849 return build_extended_vector_type(type_id, num_components);
3850 else
3851 return type_id;
3852
3853 case SPIRType::Int:
3854 return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
3855 SPIRType::UInt);
3856
3857 default:
3858 SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
3859 }
3860 }
3861
3862 default:
3863 if (num_components > type.vecsize)
3864 type_id = build_extended_vector_type(type_id, num_components);
3865 break;
3866 }
3867
3868 return type_id;
3869 }
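// Example (sketch): a vertex attribute the host declares as MSL_SHADER_INPUT_FORMAT_UINT8 but
// which the shader reads as int2 is redeclared as uint2 here; similarly, a UINT16 attribute read
// as int2 becomes uint2, and mismatches that cannot be widened this way throw instead.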
3870
3871 void CompilerMSL::mark_struct_members_packed(const SPIRType &type)
3872 {
3873 set_extended_decoration(type.self, SPIRVCrossDecorationPhysicalTypePacked);
3874
3875 // Problem case! Struct needs to be placed at an awkward alignment.
3876 // Mark every member of the child struct as packed.
3877 uint32_t mbr_cnt = uint32_t(type.member_types.size());
3878 for (uint32_t i = 0; i < mbr_cnt; i++)
3879 {
3880 auto &mbr_type = get<SPIRType>(type.member_types[i]);
3881 if (mbr_type.basetype == SPIRType::Struct)
3882 {
3883 // Recursively mark structs as packed.
3884 auto *struct_type = &mbr_type;
3885 while (!struct_type->array.empty())
3886 struct_type = &get<SPIRType>(struct_type->parent_type);
3887 mark_struct_members_packed(*struct_type);
3888 }
3889 else if (!is_scalar(mbr_type))
3890 set_extended_member_decoration(type.self, i, SPIRVCrossDecorationPhysicalTypePacked);
3891 }
3892 }
3893
3894 void CompilerMSL::mark_scalar_layout_structs(const SPIRType &type)
3895 {
3896 uint32_t mbr_cnt = uint32_t(type.member_types.size());
3897 for (uint32_t i = 0; i < mbr_cnt; i++)
3898 {
3899 auto &mbr_type = get<SPIRType>(type.member_types[i]);
3900 if (mbr_type.basetype == SPIRType::Struct)
3901 {
3902 auto *struct_type = &mbr_type;
3903 while (!struct_type->array.empty())
3904 struct_type = &get<SPIRType>(struct_type->parent_type);
3905
3906 if (has_extended_decoration(struct_type->self, SPIRVCrossDecorationPhysicalTypePacked))
3907 continue;
3908
3909 uint32_t msl_alignment = get_declared_struct_member_alignment_msl(type, i);
3910 uint32_t msl_size = get_declared_struct_member_size_msl(type, i);
3911 uint32_t spirv_offset = type_struct_member_offset(type, i);
3912 uint32_t spirv_offset_next;
3913 if (i + 1 < mbr_cnt)
3914 spirv_offset_next = type_struct_member_offset(type, i + 1);
3915 else
3916 spirv_offset_next = spirv_offset + msl_size;
3917
3918 // Both are complicated cases. In scalar layout, a struct of float3 might just consume 12 bytes,
3919 // and the next member will be placed at offset 12.
3920 bool struct_is_misaligned = (spirv_offset % msl_alignment) != 0;
3921 bool struct_is_too_large = spirv_offset + msl_size > spirv_offset_next;
3922 uint32_t array_stride = 0;
3923 bool struct_needs_explicit_padding = false;
3924
3925 // Verify that if a struct is used as an array that ArrayStride matches the effective size of the struct.
3926 if (!mbr_type.array.empty())
3927 {
3928 array_stride = type_struct_member_array_stride(type, i);
3929 uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
3930 for (uint32_t dim = 0; dim < dimensions; dim++)
3931 {
3932 uint32_t array_size = to_array_size_literal(mbr_type, dim);
3933 array_stride /= max(array_size, 1u);
3934 }
3935
3936 // Set expected struct size based on ArrayStride.
3937 struct_needs_explicit_padding = true;
3938
3939 // If struct size is larger than array stride, we might be able to fit, if we tightly pack.
3940 if (get_declared_struct_size_msl(*struct_type) > array_stride)
3941 struct_is_too_large = true;
3942 }
3943
3944 if (struct_is_misaligned || struct_is_too_large)
3945 mark_struct_members_packed(*struct_type);
3946 mark_scalar_layout_structs(*struct_type);
3947
3948 if (struct_needs_explicit_padding)
3949 {
3950 msl_size = get_declared_struct_size_msl(*struct_type, true, true);
3951 if (array_stride < msl_size)
3952 {
3953 SPIRV_CROSS_THROW("Cannot express an array stride smaller than size of struct type.");
3954 }
3955 else
3956 {
3957 if (has_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget))
3958 {
3959 if (array_stride !=
3960 get_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget))
3961 SPIRV_CROSS_THROW(
3962 "A struct is used with different array strides. Cannot express this in MSL.");
3963 }
3964 else
3965 set_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget, array_stride);
3966 }
3967 }
3968 }
3969 }
3970 }
3971
3972 // Sort the members of the struct type by offset, and pack and then pad members where needed
3973 // to align MSL members with SPIR-V offsets. The struct members are iterated twice. Packing
3974 // occurs first, followed by padding, because packing a member reduces both its size and its
3975 // natural alignment, possibly requiring a padding member to be added ahead of it.
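// For example (a sketch of the idea): with SPIR-V offsets { float3 a @ 0, float b @ 12 },
// MSL's 16-byte float3 would push b to offset 16, so a must become packed:
//
//   packed_float3 a;  // 12 bytes, 4-byte alignment
//   float b;          // lands at offset 12, matching SPIR-V
//
// whereas { float3 a @ 0, float b @ 16 } needs neither packing nor padding.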
3976 void CompilerMSL::align_struct(SPIRType &ib_type, unordered_set<uint32_t> &aligned_structs)
3977 {
3978 // We align structs recursively, so stop any redundant work.
3979 ID &ib_type_id = ib_type.self;
3980 if (aligned_structs.count(ib_type_id))
3981 return;
3982 aligned_structs.insert(ib_type_id);
3983
3984 // Sort the members of the interface structure by their offset.
3985 // They should already be sorted per SPIR-V spec anyway.
3986 MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Offset);
3987 member_sorter.sort();
3988
3989 auto mbr_cnt = uint32_t(ib_type.member_types.size());
3990
3991 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
3992 {
3993 // Pack any dependent struct types before we pack a parent struct.
3994 auto &mbr_type = get<SPIRType>(ib_type.member_types[mbr_idx]);
3995 if (mbr_type.basetype == SPIRType::Struct)
3996 align_struct(mbr_type, aligned_structs);
3997 }
3998
3999 // Test the alignment of each member, and if a member should be closer to the previous
4000 // member than the default spacing expects, it is likely that the previous member is in
4001 // a packed format. If so, and the previous member is packable, pack it.
4002 // For example, this applies to any 3-element vector that is followed by a scalar.
4003 uint32_t msl_offset = 0;
4004 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
4005 {
4006 // This checks the member in isolation, if the member needs some kind of type remapping to conform to SPIR-V
4007 // offsets, array strides and matrix strides.
4008 ensure_member_packing_rules_msl(ib_type, mbr_idx);
4009
4010 // Align current offset to the current member's default alignment. If the member was packed, it will observe
4011 // the updated alignment here.
4012 uint32_t msl_align_mask = get_declared_struct_member_alignment_msl(ib_type, mbr_idx) - 1;
4013 uint32_t aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;
4014
4015 // Fetch the member offset as declared in the SPIR-V.
4016 uint32_t spirv_mbr_offset = get_member_decoration(ib_type_id, mbr_idx, DecorationOffset);
4017 if (spirv_mbr_offset > aligned_msl_offset)
4018 {
4019 // Since MSL and SPIR-V have slightly different struct member alignment and
4020 // size rules, we'll pad to standard C-packing rules with a char[] array. If the member is farther
4021 // away than C-packing expects, add an inert padding member before the member.
4022 uint32_t padding_bytes = spirv_mbr_offset - aligned_msl_offset;
4023 set_extended_member_decoration(ib_type_id, mbr_idx, SPIRVCrossDecorationPaddingTarget, padding_bytes);
4024
4025 // Re-align as a sanity check that aligning post-padding matches up.
4026 msl_offset += padding_bytes;
4027 aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;
4028 }
4029 else if (spirv_mbr_offset < aligned_msl_offset)
4030 {
4031 // This should not happen, but deal with unexpected scenarios.
4032 // It *might* happen if a sub-struct has a larger alignment requirement in MSL than SPIR-V.
4033 SPIRV_CROSS_THROW("Cannot represent buffer block correctly in MSL.");
4034 }
4035
4036 assert(aligned_msl_offset == spirv_mbr_offset);
4037
4038 // Increment the current offset to be positioned immediately after the current member.
4039 // Don't do this for the last member since it can be unsized, and it is not relevant for padding purposes here.
4040 if (mbr_idx + 1 < mbr_cnt)
4041 msl_offset = aligned_msl_offset + get_declared_struct_member_size_msl(ib_type, mbr_idx);
4042 }
4043 }
4044
4045 bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32_t index) const
4046 {
4047 auto &mbr_type = get<SPIRType>(type.member_types[index]);
4048 uint32_t spirv_offset = get_member_decoration(type.self, index, DecorationOffset);
4049
4050 if (index + 1 < type.member_types.size())
4051 {
4052 // First, we will check offsets. If SPIR-V offset + MSL size > SPIR-V offset of next member,
4053 // we *must* perform some kind of remapping; there is no way around it.
4054 // We can always pad after this member if necessary, so that case is fine.
4055 uint32_t spirv_offset_next = get_member_decoration(type.self, index + 1, DecorationOffset);
4056 assert(spirv_offset_next >= spirv_offset);
4057 uint32_t maximum_size = spirv_offset_next - spirv_offset;
4058 uint32_t msl_mbr_size = get_declared_struct_member_size_msl(type, index);
4059 if (msl_mbr_size > maximum_size)
4060 return false;
4061 }
4062
4063 if (!mbr_type.array.empty())
4064 {
4065 // If we have an array type, array stride must match exactly with SPIR-V.
4066
4067 // An exception to this requirement is if we have one array element.
4068 // This comes from DX scalar layout workaround.
4069 // If the app tries to be cheeky and access the member out of bounds, this will not work, but this is the best we can do.
4070 // In OpAccessChain with logical memory models, access chains must be in-bounds in SPIR-V specification.
4071 bool relax_array_stride = mbr_type.array.back() == 1 && mbr_type.array_size_literal.back();
4072
4073 if (!relax_array_stride)
4074 {
4075 uint32_t spirv_array_stride = type_struct_member_array_stride(type, index);
4076 uint32_t msl_array_stride = get_declared_struct_member_array_stride_msl(type, index);
4077 if (spirv_array_stride != msl_array_stride)
4078 return false;
4079 }
4080 }
4081
4082 if (is_matrix(mbr_type))
4083 {
4084 // Need to check MatrixStride as well.
4085 uint32_t spirv_matrix_stride = type_struct_member_matrix_stride(type, index);
4086 uint32_t msl_matrix_stride = get_declared_struct_member_matrix_stride_msl(type, index);
4087 if (spirv_matrix_stride != msl_matrix_stride)
4088 return false;
4089 }
4090
4091 // Now, we check alignment.
4092 uint32_t msl_alignment = get_declared_struct_member_alignment_msl(type, index);
4093 if ((spirv_offset % msl_alignment) != 0)
4094 return false;
4095
4096 // We're in the clear.
4097 return true;
4098 }
4099
4100 // Here we need to verify that the member type we declare conforms to Offset, ArrayStride or MatrixStride restrictions.
4101 // If there is a mismatch, we need to emit remapped types, either normal types, or "packed_X" types.
4102 // In odd cases we need to emit packed and remapped types, e.g. for weird matrices or arrays with weird array strides.
4103 void CompilerMSL::ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t index)
4104 {
4105 if (validate_member_packing_rules_msl(ib_type, index))
4106 return;
4107
4108 // We failed validation.
4109 // This case will be nightmare-ish to deal with. This could possibly happen if struct alignment does not quite
4110 // match up with what we want. Scalar block layout comes to mind here where we might have to work around the rule
4111 // that struct alignment == max alignment of all members and struct size depends on this alignment.
4112 auto &mbr_type = get<SPIRType>(ib_type.member_types[index]);
4113 if (mbr_type.basetype == SPIRType::Struct)
4114 SPIRV_CROSS_THROW("Cannot perform any repacking for structs when it is used as a member of another struct.");
4115
4116 // Perform remapping here.
4117 // There is nothing to be gained by using packed scalars, so don't attempt it.
4118 if (!is_scalar(ib_type))
4119 set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
4120
4121 // Try validating again, now with packed.
4122 if (validate_member_packing_rules_msl(ib_type, index))
4123 return;
4124
4125 // We're in deep trouble, and we need to create a new PhysicalType which matches up with what we expect.
4126 // A lot of work goes here ...
4127 // We will need remapping on Load and Store to translate the types between Logical and Physical.
4128
4129 // First, we check if we have a small-vector std140 array.
4130 // We detect this if we have an array of vectors whose array stride covers more elements than the vector contains.
4131 if (!mbr_type.array.empty() && !is_matrix(mbr_type))
4132 {
4133 uint32_t array_stride = type_struct_member_array_stride(ib_type, index);
4134
4135 // Hack off array-of-arrays until we find the array stride per element we must have to make it work.
4136 uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
4137 for (uint32_t dim = 0; dim < dimensions; dim++)
4138 array_stride /= max(to_array_size_literal(mbr_type, dim), 1u);
4139
4140 uint32_t elems_per_stride = array_stride / (mbr_type.width / 8);
4141
4142 if (elems_per_stride == 3)
4143 SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
4144 else if (elems_per_stride > 4)
4145 SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
4146
4147 auto physical_type = mbr_type;
4148 physical_type.vecsize = elems_per_stride;
4149 physical_type.parent_type = 0;
4150 uint32_t type_id = ir.increase_bound_by(1);
4151 set<SPIRType>(type_id, physical_type);
4152 set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID, type_id);
4153 set_decoration(type_id, DecorationArrayStride, array_stride);
4154
4155 // Remove packed_ for vectors of size 1, 2 and 4.
4156 if (has_extended_decoration(ib_type.self, SPIRVCrossDecorationPhysicalTypePacked))
4157 SPIRV_CROSS_THROW("Unable to remove packed decoration as entire struct must be fully packed. Do not mix "
4158 "scalar and std140 layout rules.");
4159 else
4160 unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
4161 }
4162 else if (is_matrix(mbr_type))
4163 {
4164 // MatrixStride might be std140-esque.
4165 uint32_t matrix_stride = type_struct_member_matrix_stride(ib_type, index);
4166
4167 uint32_t elems_per_stride = matrix_stride / (mbr_type.width / 8);
4168
4169 if (elems_per_stride == 3)
4170 SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
4171 else if (elems_per_stride > 4)
4172 SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");
4173
4174 bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);
4175
4176 auto physical_type = mbr_type;
4177 physical_type.parent_type = 0;
4178 if (row_major)
4179 physical_type.columns = elems_per_stride;
4180 else
4181 physical_type.vecsize = elems_per_stride;
4182 uint32_t type_id = ir.increase_bound_by(1);
4183 set<SPIRType>(type_id, physical_type);
4184 set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID, type_id);
4185
4186 // Remove packed_ for vectors of size 1, 2 and 4.
4187 if (has_extended_decoration(ib_type.self, SPIRVCrossDecorationPhysicalTypePacked))
4188 SPIRV_CROSS_THROW("Unable to remove packed decoration as entire struct must be fully packed. Do not mix "
4189 "scalar and std140 layout rules.");
4190 else
4191 unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
4192 }
4193 else
4194 SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");
4195
4196 // Try validating again, now with physical type remapping.
4197 if (validate_member_packing_rules_msl(ib_type, index))
4198 return;
4199
4200 // We might have a particular odd scalar layout case where the last element of an array
4201 // does not take up as much space as the ArrayStride or MatrixStride. This can happen with DX cbuffers.
4202 // The "proper" workaround for this is extremely painful and essentially impossible in the edge case of float3[],
4203 // so we hack around it by declaring the offending array or matrix with one less array size/col/row,
4204 // and rely on padding to get the correct value. We will technically access arrays out of bounds into the padding region,
4205 // but it should spill over gracefully without too much trouble. We rely on behavior like this for unsized arrays anyways.
4206
4207 // E.g. we might observe a physical layout of:
4208 // { float2 a[2]; float b; } in cbuffer layout where ArrayStride of a is 16, but offset of b is 24, packed right after a[1] ...
4209 uint32_t type_id = get_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID);
4210 auto &type = get<SPIRType>(type_id);
4211
4212 // Modify the physical type in-place. This is safe since each physical type workaround is a copy.
4213 if (is_array(type))
4214 {
4215 if (type.array.back() > 1)
4216 {
4217 if (!type.array_size_literal.back())
4218 SPIRV_CROSS_THROW("Cannot apply scalar layout workaround with spec constant array size.");
4219 type.array.back() -= 1;
4220 }
4221 else
4222 {
4223 // We have an array of size 1, so we cannot decrement that. Our only option now is to
4224 // force a packed layout instead, and drop the physical type remap since ArrayStride is meaningless now.
4225 unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID);
4226 set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
4227 }
4228 }
4229 else if (is_matrix(type))
4230 {
4231 bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);
4232 if (!row_major)
4233 {
4234 // Slice off one column. If we only have 2 columns, this might turn the matrix into a vector with one array element instead.
4235 if (type.columns > 2)
4236 {
4237 type.columns--;
4238 }
4239 else if (type.columns == 2)
4240 {
4241 type.columns = 1;
4242 assert(type.array.empty());
4243 type.array.push_back(1);
4244 type.array_size_literal.push_back(true);
4245 }
4246 }
4247 else
4248 {
4249 // Slice off one row. If we only have 2 rows, this might turn the matrix into a vector with one array element instead.
4250 if (type.vecsize > 2)
4251 {
4252 type.vecsize--;
4253 }
4254 else if (type.vecsize == 2)
4255 {
4256 type.vecsize = type.columns;
4257 type.columns = 1;
4258 assert(type.array.empty());
4259 type.array.push_back(1);
4260 type.array_size_literal.push_back(true);
4261 }
4262 }
4263 }
4264
4265 	// This had better validate now, or we must fail gracefully.
4266 if (!validate_member_packing_rules_msl(ib_type, index))
4267 SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");
4268 }
4269
4270 void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression)
4271 {
4272 auto &type = expression_type(rhs_expression);
4273
4274 bool lhs_remapped_type = has_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypeID);
4275 bool lhs_packed_type = has_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypePacked);
4276 auto *lhs_e = maybe_get<SPIRExpression>(lhs_expression);
4277 auto *rhs_e = maybe_get<SPIRExpression>(rhs_expression);
4278
4279 bool transpose = lhs_e && lhs_e->need_transpose;
4280
4281 // No physical type remapping, and no packed type, so can just emit a store directly.
4282 if (!lhs_remapped_type && !lhs_packed_type)
4283 {
4284 		// We might not be dealing with remapped physical types or packed types,
4285 		// but we might be doing a clean store to a row-major matrix.
4286 		// In this case, we just flip transpose states and emit the store; any transpose must live in the RHS expression.
4287 if (is_matrix(type) && lhs_e && lhs_e->need_transpose)
4288 {
4289 lhs_e->need_transpose = false;
4290
4291 if (rhs_e && rhs_e->need_transpose)
4292 {
4293 // Direct copy, but might need to unpack RHS.
4294 // Skip the transpose, as we will transpose when writing to LHS and transpose(transpose(T)) == T.
4295 rhs_e->need_transpose = false;
4296 statement(to_expression(lhs_expression), " = ", to_unpacked_row_major_matrix_expression(rhs_expression),
4297 ";");
4298 rhs_e->need_transpose = true;
4299 }
4300 else
4301 statement(to_expression(lhs_expression), " = transpose(", to_unpacked_expression(rhs_expression), ");");
4302
4303 lhs_e->need_transpose = true;
4304 register_write(lhs_expression);
4305 }
4306 else if (lhs_e && lhs_e->need_transpose)
4307 {
4308 lhs_e->need_transpose = false;
4309
4310 // Storing a column to a row-major matrix. Unroll the write.
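			// e.g. "m[2] = v;" unrolls to "m[0][2] = v.x; m[1][2] = v.y; ..." because the
			// row-major matrix is physically stored transposed.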
4311 for (uint32_t c = 0; c < type.vecsize; c++)
4312 {
4313 auto lhs_expr = to_dereferenced_expression(lhs_expression);
4314 auto column_index = lhs_expr.find_last_of('[');
4315 if (column_index != string::npos)
4316 {
4317 statement(lhs_expr.insert(column_index, join('[', c, ']')), " = ",
4318 to_extract_component_expression(rhs_expression, c), ";");
4319 }
4320 }
4321 lhs_e->need_transpose = true;
4322 register_write(lhs_expression);
4323 }
4324 else
4325 CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
4326 }
4327 else if (!lhs_remapped_type && !is_matrix(type) && !transpose)
4328 {
4329 		// Even if the target type is packed, we can directly store to it. We cannot store to packed matrices directly,
4330 		// since they are declared as arrays of vectors instead, and we need the fallback path below.
4331 CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
4332 }
4333 else
4334 {
4335 // Special handling when storing to a remapped physical type.
4336 // This is mostly to deal with std140 padded matrices or vectors.
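		// E.g. a float3x3 with a 16-byte MatrixStride is remapped to a float3x4 physical
		// type; each column store below must then be cast back down to a float3 reference.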
4337
4338 TypeID physical_type_id = lhs_remapped_type ?
4339 ID(get_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypeID)) :
4340 type.self;
4341
4342 auto &physical_type = get<SPIRType>(physical_type_id);
4343
4344 if (is_matrix(type))
4345 {
4346 const char *packed_pfx = lhs_packed_type ? "packed_" : "";
4347
4348 // Packed matrices are stored as arrays of packed vectors, so we need
4349 // to assign the vectors one at a time.
4350 // For row-major matrices, we need to transpose the *right-hand* side,
4351 // not the left-hand side.
4352
4353 // Lots of cases to cover here ...
4354
4355 bool rhs_transpose = rhs_e && rhs_e->need_transpose;
4356 SPIRType write_type = type;
4357 string cast_expr;
4358
4359 // We're dealing with transpose manually.
4360 if (rhs_transpose)
4361 rhs_e->need_transpose = false;
4362
4363 if (transpose)
4364 {
4365 // We're dealing with transpose manually.
4366 lhs_e->need_transpose = false;
4367 write_type.vecsize = type.columns;
4368 write_type.columns = 1;
4369
4370 if (physical_type.columns != type.columns)
4371 cast_expr = join("(device ", packed_pfx, type_to_glsl(write_type), "&)");
4372
4373 if (rhs_transpose)
4374 {
4375 // If RHS is also transposed, we can just copy row by row.
4376 for (uint32_t i = 0; i < type.vecsize; i++)
4377 {
4378 statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ",
4379 to_unpacked_row_major_matrix_expression(rhs_expression), "[", i, "];");
4380 }
4381 }
4382 else
4383 {
4384 auto vector_type = expression_type(rhs_expression);
4385 vector_type.vecsize = vector_type.columns;
4386 vector_type.columns = 1;
4387
4388 // Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
4389 // so pick out individual components instead.
4390 for (uint32_t i = 0; i < type.vecsize; i++)
4391 {
4392 string rhs_row = type_to_glsl_constructor(vector_type) + "(";
4393 for (uint32_t j = 0; j < vector_type.vecsize; j++)
4394 {
4395 rhs_row += join(to_enclosed_unpacked_expression(rhs_expression), "[", j, "][", i, "]");
4396 if (j + 1 < vector_type.vecsize)
4397 rhs_row += ", ";
4398 }
4399 rhs_row += ")";
4400
4401 statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";");
4402 }
4403 }
4404
4405 // We're dealing with transpose manually.
4406 lhs_e->need_transpose = true;
4407 }
4408 else
4409 {
4410 write_type.columns = 1;
4411
4412 if (physical_type.vecsize != type.vecsize)
4413 cast_expr = join("(device ", packed_pfx, type_to_glsl(write_type), "&)");
4414
4415 if (rhs_transpose)
4416 {
4417 auto vector_type = expression_type(rhs_expression);
4418 vector_type.columns = 1;
4419
4420 // Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
4421 // so pick out individual components instead.
4422 for (uint32_t i = 0; i < type.columns; i++)
4423 {
4424 string rhs_row = type_to_glsl_constructor(vector_type) + "(";
4425 for (uint32_t j = 0; j < vector_type.vecsize; j++)
4426 {
4427 // Need to explicitly unpack expression since we've mucked with transpose state.
4428 auto unpacked_expr = to_unpacked_row_major_matrix_expression(rhs_expression);
4429 rhs_row += join(unpacked_expr, "[", j, "][", i, "]");
4430 if (j + 1 < vector_type.vecsize)
4431 rhs_row += ", ";
4432 }
4433 rhs_row += ")";
4434
4435 statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";");
4436 }
4437 }
4438 else
4439 {
4440 // Copy column-by-column.
4441 for (uint32_t i = 0; i < type.columns; i++)
4442 {
4443 statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ",
4444 to_enclosed_unpacked_expression(rhs_expression), "[", i, "];");
4445 }
4446 }
4447 }
4448
4449 // We're dealing with transpose manually.
4450 if (rhs_transpose)
4451 rhs_e->need_transpose = true;
4452 }
4453 else if (transpose)
4454 {
4455 lhs_e->need_transpose = false;
4456
4457 SPIRType write_type = type;
4458 write_type.vecsize = 1;
4459 write_type.columns = 1;
4460
4461 // Storing a column to a row-major matrix. Unroll the write.
4462 for (uint32_t c = 0; c < type.vecsize; c++)
4463 {
4464 auto lhs_expr = to_enclosed_expression(lhs_expression);
4465 auto column_index = lhs_expr.find_last_of('[');
4466 if (column_index != string::npos)
4467 {
4468 statement("((device ", type_to_glsl(write_type), "*)&",
4469 lhs_expr.insert(column_index, join('[', c, ']', ")")), " = ",
4470 to_extract_component_expression(rhs_expression, c), ";");
4471 }
4472 }
4473
4474 lhs_e->need_transpose = true;
4475 }
4476 else if ((is_matrix(physical_type) || is_array(physical_type)) && physical_type.vecsize > type.vecsize)
4477 {
4478 assert(type.vecsize >= 1 && type.vecsize <= 3);
4479
4480 // If we have packed types, we cannot use swizzled stores.
4481 // We could technically unroll the store for each element if needed.
4482 // When remapping to a std140 physical type, we always get float4,
4483 // and the packed decoration should always be removed.
4484 assert(!lhs_packed_type);
4485
4486 string lhs = to_dereferenced_expression(lhs_expression);
4487 string rhs = to_pointer_expression(rhs_expression);
4488
4489 // Unpack the expression so we can store to it with a float or float2.
4490 			// It's still an l-value, so it's fine. Most other unpacking of expressions turns them into r-values instead.
4491 lhs = join("(device ", type_to_glsl(type), "&)", enclose_expression(lhs));
4492 if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
4493 statement(lhs, " = ", rhs, ";");
4494 }
4495 else if (!is_matrix(type))
4496 {
4497 string lhs = to_dereferenced_expression(lhs_expression);
4498 string rhs = to_pointer_expression(rhs_expression);
4499 if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
4500 statement(lhs, " = ", rhs, ";");
4501 }
4502
4503 register_write(lhs_expression);
4504 }
4505 }
4506
4507 static bool expression_ends_with(const string &expr_str, const std::string &ending)
4508 {
4509 if (expr_str.length() >= ending.length())
4510 return (expr_str.compare(expr_str.length() - ending.length(), ending.length(), ending) == 0);
4511 else
4512 return false;
4513 }
4514
4515 // Converts the format of the current expression from packed to unpacked,
4516 // by wrapping the expression in a constructor of the appropriate type.
4517 // Also, handle special physical ID remapping scenarios, similar to emit_store_statement().
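// E.g. a packed_float3 member remapped to a float4 physical type unpacks as "expr.xyz",
// and a packed matrix as a constructor call, e.g. "float2x2(float2(m[0]), float2(m[1]))".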
4518 string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type, uint32_t physical_type_id,
4519 bool packed, bool row_major)
4520 {
4521 // Trivial case, nothing to do.
4522 if (physical_type_id == 0 && !packed)
4523 return expr_str;
4524
4525 const SPIRType *physical_type = nullptr;
4526 if (physical_type_id)
4527 physical_type = &get<SPIRType>(physical_type_id);
4528
4529 static const char *swizzle_lut[] = {
4530 ".x",
4531 ".xy",
4532 ".xyz",
4533 };
4534
4535 if (physical_type && is_vector(*physical_type) && is_array(*physical_type) &&
4536 physical_type->vecsize > type.vecsize && !expression_ends_with(expr_str, swizzle_lut[type.vecsize - 1]))
4537 {
4538 // std140 array cases for vectors.
4539 assert(type.vecsize >= 1 && type.vecsize <= 3);
4540 return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
4541 }
4542 else if (physical_type && is_matrix(*physical_type) && is_vector(type) && physical_type->vecsize > type.vecsize)
4543 {
4544 // Extract column from padded matrix.
4545 assert(type.vecsize >= 1 && type.vecsize <= 3);
4546 return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
4547 }
4548 else if (is_matrix(type))
4549 {
4550 // Packed matrices are stored as arrays of packed vectors. Unfortunately,
4551 // we can't just pass the array straight to the matrix constructor. We have to
4552 // pass each vector individually, so that they can be unpacked to normal vectors.
4553 if (!physical_type)
4554 physical_type = &type;
4555
4556 uint32_t vecsize = type.vecsize;
4557 uint32_t columns = type.columns;
4558 if (row_major)
4559 swap(vecsize, columns);
4560
4561 uint32_t physical_vecsize = row_major ? physical_type->columns : physical_type->vecsize;
4562
4563 const char *base_type = type.width == 16 ? "half" : "float";
4564 string unpack_expr = join(base_type, columns, "x", vecsize, "(");
4565
4566 const char *load_swiz = "";
4567
4568 if (physical_vecsize != vecsize)
4569 load_swiz = swizzle_lut[vecsize - 1];
4570
4571 for (uint32_t i = 0; i < columns; i++)
4572 {
4573 if (i > 0)
4574 unpack_expr += ", ";
4575
4576 if (packed)
4577 unpack_expr += join(base_type, physical_vecsize, "(", expr_str, "[", i, "]", ")", load_swiz);
4578 else
4579 unpack_expr += join(expr_str, "[", i, "]", load_swiz);
4580 }
4581
4582 unpack_expr += ")";
4583 return unpack_expr;
4584 }
4585 else
4586 {
4587 return join(type_to_glsl(type), "(", expr_str, ")");
4588 }
4589 }
4590
4591 // Emits the file header info
4592 void CompilerMSL::emit_header()
4593 {
4594 // This particular line can be overridden during compilation, so make it a flag and not a pragma line.
4595 if (suppress_missing_prototypes)
4596 statement("#pragma clang diagnostic ignored \"-Wmissing-prototypes\"");
4597
4598 // Disable warning about missing braces for array<T> template to make arrays a value type
4599 if (spv_function_implementations.count(SPVFuncImplUnsafeArray) != 0)
4600 statement("#pragma clang diagnostic ignored \"-Wmissing-braces\"");
4601
4602 for (auto &pragma : pragma_lines)
4603 statement(pragma);
4604
4605 if (!pragma_lines.empty() || suppress_missing_prototypes)
4606 statement("");
4607
4608 statement("#include <metal_stdlib>");
4609 statement("#include <simd/simd.h>");
4610
4611 for (auto &header : header_lines)
4612 statement(header);
4613
4614 statement("");
4615 statement("using namespace metal;");
4616 statement("");
4617
4618 for (auto &td : typedef_lines)
4619 statement(td);
4620
4621 if (!typedef_lines.empty())
4622 statement("");
4623 }
4624
4625 void CompilerMSL::add_pragma_line(const string &line)
4626 {
4627 auto rslt = pragma_lines.insert(line);
4628 if (rslt.second)
4629 force_recompile();
4630 }
4631
4632 void CompilerMSL::add_typedef_line(const string &line)
4633 {
4634 auto rslt = typedef_lines.insert(line);
4635 if (rslt.second)
4636 force_recompile();
4637 }
4638
4639 // Template structs like spvUnsafeArray<> need to be declared *before* any resources are declared
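// e.g. spvUnsafeArray<float, 4> behaves like float[4], but unlike a plain C array
// it can be copied, assigned, and returned by value.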
4640 void CompilerMSL::emit_custom_templates()
4641 {
4642 for (const auto &spv_func : spv_function_implementations)
4643 {
4644 switch (spv_func)
4645 {
4646 case SPVFuncImplUnsafeArray:
4647 statement("template<typename T, size_t Num>");
4648 statement("struct spvUnsafeArray");
4649 begin_scope();
4650 statement("T elements[Num ? Num : 1];");
4651 statement("");
4652 statement("thread T& operator [] (size_t pos) thread");
4653 begin_scope();
4654 statement("return elements[pos];");
4655 end_scope();
4656 statement("constexpr const thread T& operator [] (size_t pos) const thread");
4657 begin_scope();
4658 statement("return elements[pos];");
4659 end_scope();
4660 statement("");
4661 statement("device T& operator [] (size_t pos) device");
4662 begin_scope();
4663 statement("return elements[pos];");
4664 end_scope();
4665 statement("constexpr const device T& operator [] (size_t pos) const device");
4666 begin_scope();
4667 statement("return elements[pos];");
4668 end_scope();
4669 statement("");
4670 statement("constexpr const constant T& operator [] (size_t pos) const constant");
4671 begin_scope();
4672 statement("return elements[pos];");
4673 end_scope();
4674 statement("");
4675 statement("threadgroup T& operator [] (size_t pos) threadgroup");
4676 begin_scope();
4677 statement("return elements[pos];");
4678 end_scope();
4679 statement("constexpr const threadgroup T& operator [] (size_t pos) const threadgroup");
4680 begin_scope();
4681 statement("return elements[pos];");
4682 end_scope();
4683 end_scope_decl();
4684 statement("");
4685 break;
4686
4687 default:
4688 break;
4689 }
4690 }
4691 }
4692
4693 // Emits any needed custom function bodies.
4694 // Metal helper functions must be static force-inline, i.e. static inline __attribute__((always_inline)),
4695 // otherwise they will cause problems when linked together in a single Metallib.
4696 void CompilerMSL::emit_custom_functions()
4697 {
4698 for (uint32_t i = kArrayCopyMultidimMax; i >= 2; i--)
4699 if (spv_function_implementations.count(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i)))
4700 spv_function_implementations.insert(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i - 1));
4701
4702 if (spv_function_implementations.count(SPVFuncImplDynamicImageSampler))
4703 {
4704 // Unfortunately, this one needs a lot of the other functions to compile OK.
4705 if (!msl_options.supports_msl_version(2))
4706 SPIRV_CROSS_THROW(
4707 "spvDynamicImageSampler requires default-constructible texture objects, which require MSL 2.0.");
4708 spv_function_implementations.insert(SPVFuncImplForwardArgs);
4709 spv_function_implementations.insert(SPVFuncImplTextureSwizzle);
4710 if (msl_options.swizzle_texture_samples)
4711 spv_function_implementations.insert(SPVFuncImplGatherSwizzle);
4712 for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
4713 i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
4714 spv_function_implementations.insert(static_cast<SPVFuncImpl>(i));
4715 spv_function_implementations.insert(SPVFuncImplExpandITUFullRange);
4716 spv_function_implementations.insert(SPVFuncImplExpandITUNarrowRange);
4717 spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT709);
4718 spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT601);
4719 spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT2020);
4720 }
4721
4722 for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
4723 i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
4724 if (spv_function_implementations.count(static_cast<SPVFuncImpl>(i)))
4725 spv_function_implementations.insert(SPVFuncImplForwardArgs);
4726
4727 if (spv_function_implementations.count(SPVFuncImplTextureSwizzle) ||
4728 spv_function_implementations.count(SPVFuncImplGatherSwizzle) ||
4729 spv_function_implementations.count(SPVFuncImplGatherCompareSwizzle))
4730 {
4731 spv_function_implementations.insert(SPVFuncImplForwardArgs);
4732 spv_function_implementations.insert(SPVFuncImplGetSwizzle);
4733 }
4734
4735 for (const auto &spv_func : spv_function_implementations)
4736 {
4737 switch (spv_func)
4738 {
4739 case SPVFuncImplMod:
4740 statement("// Implementation of the GLSL mod() function, which is slightly different than Metal fmod()");
4741 statement("template<typename Tx, typename Ty>");
4742 statement("inline Tx mod(Tx x, Ty y)");
4743 begin_scope();
4744 statement("return x - y * floor(x / y);");
4745 end_scope();
4746 statement("");
4747 break;
4748
4749 case SPVFuncImplRadians:
4750 statement("// Implementation of the GLSL radians() function");
4751 statement("template<typename T>");
4752 statement("inline T radians(T d)");
4753 begin_scope();
4754 statement("return d * T(0.01745329251);");
4755 end_scope();
4756 statement("");
4757 break;
4758
4759 case SPVFuncImplDegrees:
4760 statement("// Implementation of the GLSL degrees() function");
4761 statement("template<typename T>");
4762 statement("inline T degrees(T r)");
4763 begin_scope();
4764 statement("return r * T(57.2957795131);");
4765 end_scope();
4766 statement("");
4767 break;
4768
4769 case SPVFuncImplFindILsb:
4770 statement("// Implementation of the GLSL findLSB() function");
4771 statement("template<typename T>");
4772 statement("inline T spvFindLSB(T x)");
4773 begin_scope();
4774 statement("return select(ctz(x), T(-1), x == T(0));");
4775 end_scope();
4776 statement("");
4777 break;
4778
4779 case SPVFuncImplFindUMsb:
4780 statement("// Implementation of the unsigned GLSL findMSB() function");
4781 statement("template<typename T>");
4782 statement("inline T spvFindUMSB(T x)");
4783 begin_scope();
4784 statement("return select(clz(T(0)) - (clz(x) + T(1)), T(-1), x == T(0));");
4785 end_scope();
4786 statement("");
4787 break;
4788
4789 case SPVFuncImplFindSMsb:
4790 statement("// Implementation of the signed GLSL findMSB() function");
4791 statement("template<typename T>");
4792 statement("inline T spvFindSMSB(T x)");
4793 begin_scope();
4794 statement("T v = select(x, T(-1) - x, x < T(0));");
4795 statement("return select(clz(T(0)) - (clz(v) + T(1)), T(-1), v == T(0));");
4796 end_scope();
4797 statement("");
4798 break;
4799
4800 case SPVFuncImplSSign:
4801 statement("// Implementation of the GLSL sign() function for integer types");
4802 statement("template<typename T, typename E = typename enable_if<is_integral<T>::value>::type>");
4803 statement("inline T sign(T x)");
4804 begin_scope();
4805 statement("return select(select(select(x, T(0), x == T(0)), T(1), x > T(0)), T(-1), x < T(0));");
4806 end_scope();
4807 statement("");
4808 break;
4809
4810 case SPVFuncImplArrayCopy:
4811 case SPVFuncImplArrayOfArrayCopy2Dim:
4812 case SPVFuncImplArrayOfArrayCopy3Dim:
4813 case SPVFuncImplArrayOfArrayCopy4Dim:
4814 case SPVFuncImplArrayOfArrayCopy5Dim:
4815 case SPVFuncImplArrayOfArrayCopy6Dim:
4816 {
4817 // Unfortunately we cannot template on the address space, so combinatorial explosion it is.
4818 static const char *function_name_tags[] = {
4819 "FromConstantToStack", "FromConstantToThreadGroup", "FromStackToStack",
4820 "FromStackToThreadGroup", "FromThreadGroupToStack", "FromThreadGroupToThreadGroup",
4821 "FromDeviceToDevice", "FromConstantToDevice", "FromStackToDevice",
4822 "FromThreadGroupToDevice", "FromDeviceToStack", "FromDeviceToThreadGroup",
4823 };
4824
4825 static const char *src_address_space[] = {
4826 "constant", "constant", "thread const", "thread const",
4827 "threadgroup const", "threadgroup const", "device const", "constant",
4828 "thread const", "threadgroup const", "device const", "device const",
4829 };
4830
4831 static const char *dst_address_space[] = {
4832 "thread", "threadgroup", "thread", "threadgroup", "thread", "threadgroup",
4833 "device", "device", "device", "device", "thread", "threadgroup",
4834 };
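			// For example, variant 0 at 2 dimensions emits:
			//   template<typename T, uint A, uint B>
			//   inline void spvArrayCopyFromConstantToStack2(thread T (&dst)[A][B], constant T (&src)[A][B])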
4835
4836 for (uint32_t variant = 0; variant < 12; variant++)
4837 {
4838 uint32_t dimensions = spv_func - SPVFuncImplArrayCopyMultidimBase;
4839 string tmp = "template<typename T";
4840 for (uint8_t i = 0; i < dimensions; i++)
4841 {
4842 tmp += ", uint ";
4843 tmp += 'A' + i;
4844 }
4845 tmp += ">";
4846 statement(tmp);
4847
4848 string array_arg;
4849 for (uint8_t i = 0; i < dimensions; i++)
4850 {
4851 array_arg += "[";
4852 array_arg += 'A' + i;
4853 array_arg += "]";
4854 }
4855
4856 statement("inline void spvArrayCopy", function_name_tags[variant], dimensions, "(",
4857 dst_address_space[variant], " T (&dst)", array_arg, ", ", src_address_space[variant],
4858 " T (&src)", array_arg, ")");
4859
4860 begin_scope();
4861 statement("for (uint i = 0; i < A; i++)");
4862 begin_scope();
4863
4864 if (dimensions == 1)
4865 statement("dst[i] = src[i];");
4866 else
4867 statement("spvArrayCopy", function_name_tags[variant], dimensions - 1, "(dst[i], src[i]);");
4868 end_scope();
4869 end_scope();
4870 statement("");
4871 }
4872 break;
4873 }
4874
4875 // Support for Metal 2.1's new texture_buffer type.
4876 case SPVFuncImplTexelBufferCoords:
4877 {
4878 if (msl_options.texel_buffer_texture_width > 0)
4879 {
4880 string tex_width_str = convert_to_string(msl_options.texel_buffer_texture_width);
4881 statement("// Returns 2D texture coords corresponding to 1D texel buffer coords");
4882 statement(force_inline);
4883 statement("uint2 spvTexelBufferCoord(uint tc)");
4884 begin_scope();
4885 statement(join("return uint2(tc % ", tex_width_str, ", tc / ", tex_width_str, ");"));
4886 end_scope();
4887 statement("");
4888 }
4889 else
4890 {
4891 statement("// Returns 2D texture coords corresponding to 1D texel buffer coords");
4892 statement(
4893 "#define spvTexelBufferCoord(tc, tex) uint2((tc) % (tex).get_width(), (tc) / (tex).get_width())");
4894 statement("");
4895 }
4896 break;
4897 }
4898
4899 // Emulate texture2D atomic operations
4900 case SPVFuncImplImage2DAtomicCoords:
4901 {
4902 if (msl_options.supports_msl_version(1, 2))
4903 {
4904 statement("// The required alignment of a linear texture of R32Uint format.");
4905 statement("constant uint spvLinearTextureAlignmentOverride [[function_constant(",
4906 msl_options.r32ui_alignment_constant_id, ")]];");
4907 statement("constant uint spvLinearTextureAlignment = ",
4908 "is_function_constant_defined(spvLinearTextureAlignmentOverride) ? ",
4909 "spvLinearTextureAlignmentOverride : ", msl_options.r32ui_linear_texture_alignment, ";");
4910 }
4911 else
4912 {
4913 statement("// The required alignment of a linear texture of R32Uint format.");
4914 statement("constant uint spvLinearTextureAlignment = ", msl_options.r32ui_linear_texture_alignment,
4915 ";");
4916 }
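			// The macro below rounds the texture width up to a whole multiple of
			// spvLinearTextureAlignment / 4 texels (R32Uint is 4 bytes per texel)
			// before linearizing the 2D coordinate into the backing buffer.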
4917 statement("// Returns buffer coords corresponding to 2D texture coords for emulating 2D texture atomics");
4918 statement("#define spvImage2DAtomicCoord(tc, tex) (((((tex).get_width() + ",
4919 " spvLinearTextureAlignment / 4 - 1) & ~(",
4920 " spvLinearTextureAlignment / 4 - 1)) * (tc).y) + (tc).x)");
4921 statement("");
4922 break;
4923 }
4924
4925 // "fadd" intrinsic support
4926 case SPVFuncImplFAdd:
4927 statement("template<typename T>");
4928 statement("T spvFAdd(T l, T r)");
4929 begin_scope();
4930 statement("return fma(T(1), l, r);");
4931 end_scope();
4932 statement("");
4933 break;
4934
4935 // "fsub" intrinsic support
4936 case SPVFuncImplFSub:
4937 statement("template<typename T>");
4938 statement("T spvFSub(T l, T r)");
4939 begin_scope();
4940 statement("return fma(T(-1), r, l);");
4941 end_scope();
4942 statement("");
4943 break;
4944
4945 // "fmul' intrinsic support
4946 case SPVFuncImplFMul:
4947 statement("template<typename T>");
4948 statement("T spvFMul(T l, T r)");
4949 begin_scope();
4950 statement("return fma(l, r, T(0));");
4951 end_scope();
4952 statement("");
4953
4954 statement("template<typename T, int Cols, int Rows>");
4955 statement("vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)");
4956 begin_scope();
4957 statement("vec<T, Cols> res = vec<T, Cols>(0);");
4958 statement("for (uint i = Rows; i > 0; --i)");
4959 begin_scope();
4960 statement("vec<T, Cols> tmp(0);");
4961 statement("for (uint j = 0; j < Cols; ++j)");
4962 begin_scope();
4963 statement("tmp[j] = m[j][i - 1];");
4964 end_scope();
4965 statement("res = fma(tmp, vec<T, Cols>(v[i - 1]), res);");
4966 end_scope();
4967 statement("return res;");
4968 end_scope();
4969 statement("");
4970
4971 statement("template<typename T, int Cols, int Rows>");
4972 statement("vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)");
4973 begin_scope();
4974 statement("vec<T, Rows> res = vec<T, Rows>(0);");
4975 statement("for (uint i = Cols; i > 0; --i)");
4976 begin_scope();
4977 statement("res = fma(m[i - 1], vec<T, Rows>(v[i - 1]), res);");
4978 end_scope();
4979 statement("return res;");
4980 end_scope();
4981 statement("");
4982
4983 statement("template<typename T, int LCols, int LRows, int RCols, int RRows>");
4984 statement(
4985 "matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)");
4986 begin_scope();
4987 statement("matrix<T, RCols, LRows> res;");
4988 statement("for (uint i = 0; i < RCols; i++)");
4989 begin_scope();
4990 statement("vec<T, RCols> tmp(0);");
4991 statement("for (uint j = 0; j < LCols; j++)");
4992 begin_scope();
4993 statement("tmp = fma(vec<T, RCols>(r[i][j]), l[j], tmp);");
4994 end_scope();
4995 statement("res[i] = tmp;");
4996 end_scope();
4997 statement("return res;");
4998 end_scope();
4999 statement("");
5000 break;
5001
5002 // Emulate texturecube_array with texture2d_array for iOS where this type is not available
5003 case SPVFuncImplCubemapTo2DArrayFace:
5004 statement(force_inline);
5005 statement("float3 spvCubemapTo2DArrayFace(float3 P)");
5006 begin_scope();
5007 statement("float3 Coords = abs(P.xyz);");
5008 statement("float CubeFace = 0;");
5009 statement("float ProjectionAxis = 0;");
5010 statement("float u = 0;");
5011 statement("float v = 0;");
5012 statement("if (Coords.x >= Coords.y && Coords.x >= Coords.z)");
5013 begin_scope();
5014 statement("CubeFace = P.x >= 0 ? 0 : 1;");
5015 statement("ProjectionAxis = Coords.x;");
5016 statement("u = P.x >= 0 ? -P.z : P.z;");
5017 statement("v = -P.y;");
5018 end_scope();
5019 statement("else if (Coords.y >= Coords.x && Coords.y >= Coords.z)");
5020 begin_scope();
5021 statement("CubeFace = P.y >= 0 ? 2 : 3;");
5022 statement("ProjectionAxis = Coords.y;");
5023 statement("u = P.x;");
5024 statement("v = P.y >= 0 ? P.z : -P.z;");
5025 end_scope();
5026 statement("else");
5027 begin_scope();
5028 statement("CubeFace = P.z >= 0 ? 4 : 5;");
5029 statement("ProjectionAxis = Coords.z;");
5030 statement("u = P.z >= 0 ? P.x : -P.x;");
5031 statement("v = -P.y;");
5032 end_scope();
5033 statement("u = 0.5 * (u/ProjectionAxis + 1);");
5034 statement("v = 0.5 * (v/ProjectionAxis + 1);");
5035 statement("return float3(u, v, CubeFace);");
5036 end_scope();
5037 statement("");
5038 break;
5039
5040 case SPVFuncImplInverse4x4:
5041 statement("// Returns the determinant of a 2x2 matrix.");
5042 statement(force_inline);
5043 statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
5044 begin_scope();
5045 statement("return a1 * b2 - b1 * a2;");
5046 end_scope();
5047 statement("");
5048
5049 statement("// Returns the determinant of a 3x3 matrix.");
5050 statement(force_inline);
5051 statement("float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
5052 "float c2, float c3)");
5053 begin_scope();
5054 statement("return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * spvDet2x2(a2, a3, "
5055 "b2, b3);");
5056 end_scope();
5057 statement("");
5058 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
5059 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
5060 statement(force_inline);
5061 statement("float4x4 spvInverse4x4(float4x4 m)");
5062 begin_scope();
5063 statement("float4x4 adj; // The adjoint matrix (inverse after dividing by determinant)");
5064 statement_no_indent("");
5065 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
5066 statement("adj[0][0] = spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
5067 "m[3][3]);");
5068 statement("adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
5069 "m[3][3]);");
5070 statement("adj[0][2] = spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
5071 "m[3][3]);");
5072 statement("adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
5073 "m[2][3]);");
5074 statement_no_indent("");
5075 statement("adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
5076 "m[3][3]);");
5077 statement("adj[1][1] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
5078 "m[3][3]);");
5079 statement("adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
5080 "m[3][3]);");
5081 statement("adj[1][3] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
5082 "m[2][3]);");
5083 statement_no_indent("");
5084 statement("adj[2][0] = spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
5085 "m[3][3]);");
5086 statement("adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
5087 "m[3][3]);");
5088 statement("adj[2][2] = spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
5089 "m[3][3]);");
5090 statement("adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
5091 "m[2][3]);");
5092 statement_no_indent("");
5093 statement("adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
5094 "m[3][2]);");
5095 statement("adj[3][1] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
5096 "m[3][2]);");
5097 statement("adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
5098 "m[3][2]);");
5099 statement("adj[3][3] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
5100 "m[2][2]);");
5101 statement_no_indent("");
5102 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
5103 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
5104 "* m[3][0]);");
5105 statement_no_indent("");
5106 statement("// Divide the classical adjoint matrix by the determinant.");
5107 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
5108 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
5109 end_scope();
5110 statement("");
5111 break;
5112
5113 case SPVFuncImplInverse3x3:
5114 if (spv_function_implementations.count(SPVFuncImplInverse4x4) == 0)
5115 {
5116 statement("// Returns the determinant of a 2x2 matrix.");
5117 statement(force_inline);
5118 statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
5119 begin_scope();
5120 statement("return a1 * b2 - b1 * a2;");
5121 end_scope();
5122 statement("");
5123 }
5124
5125 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
5126 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
5127 statement(force_inline);
5128 statement("float3x3 spvInverse3x3(float3x3 m)");
5129 begin_scope();
5130 statement("float3x3 adj; // The adjoint matrix (inverse after dividing by determinant)");
5131 statement_no_indent("");
5132 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
5133 statement("adj[0][0] = spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
5134 statement("adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
5135 statement("adj[0][2] = spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
5136 statement_no_indent("");
5137 statement("adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
5138 statement("adj[1][1] = spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
5139 statement("adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
5140 statement_no_indent("");
5141 statement("adj[2][0] = spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
5142 statement("adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
5143 statement("adj[2][2] = spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
5144 statement_no_indent("");
5145 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
5146 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
5147 statement_no_indent("");
5148 statement("// Divide the classical adjoint matrix by the determinant.");
5149 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
5150 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
5151 end_scope();
5152 statement("");
5153 break;
5154
5155 case SPVFuncImplInverse2x2:
5156 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
5157 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
5158 statement(force_inline);
5159 statement("float2x2 spvInverse2x2(float2x2 m)");
5160 begin_scope();
5161 statement("float2x2 adj; // The adjoint matrix (inverse after dividing by determinant)");
5162 statement_no_indent("");
5163 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
5164 statement("adj[0][0] = m[1][1];");
5165 statement("adj[0][1] = -m[0][1];");
5166 statement_no_indent("");
5167 statement("adj[1][0] = -m[1][0];");
5168 statement("adj[1][1] = m[0][0];");
5169 statement_no_indent("");
5170 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
5171 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
5172 statement_no_indent("");
5173 statement("// Divide the classical adjoint matrix by the determinant.");
5174 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
5175 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
5176 end_scope();
5177 statement("");
5178 break;
5179
5180 case SPVFuncImplForwardArgs:
5181 statement("template<typename T> struct spvRemoveReference { typedef T type; };");
5182 statement("template<typename T> struct spvRemoveReference<thread T&> { typedef T type; };");
5183 statement("template<typename T> struct spvRemoveReference<thread T&&> { typedef T type; };");
5184 statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
5185 "spvRemoveReference<T>::type& x)");
5186 begin_scope();
5187 statement("return static_cast<thread T&&>(x);");
5188 end_scope();
5189 statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
5190 "spvRemoveReference<T>::type&& x)");
5191 begin_scope();
5192 statement("return static_cast<thread T&&>(x);");
5193 end_scope();
5194 statement("");
5195 break;
5196
5197 case SPVFuncImplGetSwizzle:
5198 statement("enum class spvSwizzle : uint");
5199 begin_scope();
5200 statement("none = 0,");
5201 statement("zero,");
5202 statement("one,");
5203 statement("red,");
5204 statement("green,");
5205 statement("blue,");
5206 statement("alpha");
5207 end_scope_decl();
5208 statement("");
5209 statement("template<typename T>");
5210 statement("inline T spvGetSwizzle(vec<T, 4> x, T c, spvSwizzle s)");
5211 begin_scope();
5212 statement("switch (s)");
5213 begin_scope();
5214 statement("case spvSwizzle::none:");
5215 statement(" return c;");
5216 statement("case spvSwizzle::zero:");
5217 statement(" return 0;");
5218 statement("case spvSwizzle::one:");
5219 statement(" return 1;");
5220 statement("case spvSwizzle::red:");
5221 statement(" return x.r;");
5222 statement("case spvSwizzle::green:");
5223 statement(" return x.g;");
5224 statement("case spvSwizzle::blue:");
5225 statement(" return x.b;");
5226 statement("case spvSwizzle::alpha:");
5227 statement(" return x.a;");
5228 end_scope();
5229 end_scope();
5230 statement("");
5231 break;
5232
5233 case SPVFuncImplTextureSwizzle:
5234 statement("// Wrapper function that swizzles texture samples and fetches.");
5235 statement("template<typename T>");
5236 statement("inline vec<T, 4> spvTextureSwizzle(vec<T, 4> x, uint s)");
5237 begin_scope();
5238 statement("if (!s)");
5239 statement(" return x;");
5240 statement("return vec<T, 4>(spvGetSwizzle(x, x.r, spvSwizzle((s >> 0) & 0xFF)), "
5241 "spvGetSwizzle(x, x.g, spvSwizzle((s >> 8) & 0xFF)), spvGetSwizzle(x, x.b, spvSwizzle((s >> 16) "
5242 "& 0xFF)), "
5243 "spvGetSwizzle(x, x.a, spvSwizzle((s >> 24) & 0xFF)));");
5244 end_scope();
5245 statement("");
5246 statement("template<typename T>");
5247 statement("inline T spvTextureSwizzle(T x, uint s)");
5248 begin_scope();
5249 statement("return spvTextureSwizzle(vec<T, 4>(x, 0, 0, 1), s).x;");
5250 end_scope();
5251 statement("");
5252 break;
5253
5254 case SPVFuncImplGatherSwizzle:
5255 statement("// Wrapper function that swizzles texture gathers.");
5256 statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
5257 "typename... Ts>");
5258 statement("inline vec<T, 4> spvGatherSwizzle(const thread Tex<T>& t, sampler s, "
5259 "uint sw, component c, Ts... params) METAL_CONST_ARG(c)");
5260 begin_scope();
5261 statement("if (sw)");
5262 begin_scope();
5263 statement("switch (spvSwizzle((sw >> (uint(c) * 8)) & 0xFF))");
5264 begin_scope();
5265 statement("case spvSwizzle::none:");
5266 statement(" break;");
5267 statement("case spvSwizzle::zero:");
5268 statement(" return vec<T, 4>(0, 0, 0, 0);");
5269 statement("case spvSwizzle::one:");
5270 statement(" return vec<T, 4>(1, 1, 1, 1);");
5271 statement("case spvSwizzle::red:");
5272 statement(" return t.gather(s, spvForward<Ts>(params)..., component::x);");
5273 statement("case spvSwizzle::green:");
5274 statement(" return t.gather(s, spvForward<Ts>(params)..., component::y);");
5275 statement("case spvSwizzle::blue:");
5276 statement(" return t.gather(s, spvForward<Ts>(params)..., component::z);");
5277 statement("case spvSwizzle::alpha:");
5278 statement(" return t.gather(s, spvForward<Ts>(params)..., component::w);");
5279 end_scope();
5280 end_scope();
5281 // texture::gather insists on its component parameter being a constant
5282 // expression, so we need this silly workaround just to compile the shader.
5283 statement("switch (c)");
5284 begin_scope();
5285 statement("case component::x:");
5286 statement(" return t.gather(s, spvForward<Ts>(params)..., component::x);");
5287 statement("case component::y:");
5288 statement(" return t.gather(s, spvForward<Ts>(params)..., component::y);");
5289 statement("case component::z:");
5290 statement(" return t.gather(s, spvForward<Ts>(params)..., component::z);");
5291 statement("case component::w:");
5292 statement(" return t.gather(s, spvForward<Ts>(params)..., component::w);");
5293 end_scope();
5294 end_scope();
5295 statement("");
5296 break;
5297
5298 case SPVFuncImplGatherCompareSwizzle:
5299 statement("// Wrapper function that swizzles depth texture gathers.");
5300 statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
5301 "typename... Ts>");
5302 statement("inline vec<T, 4> spvGatherCompareSwizzle(const thread Tex<T>& t, sampler "
5303 "s, uint sw, Ts... params) ");
5304 begin_scope();
5305 statement("if (sw)");
5306 begin_scope();
5307 statement("switch (spvSwizzle(sw & 0xFF))");
5308 begin_scope();
5309 statement("case spvSwizzle::none:");
5310 statement("case spvSwizzle::red:");
5311 statement(" break;");
5312 statement("case spvSwizzle::zero:");
5313 statement("case spvSwizzle::green:");
5314 statement("case spvSwizzle::blue:");
5315 statement("case spvSwizzle::alpha:");
5316 statement(" return vec<T, 4>(0, 0, 0, 0);");
5317 statement("case spvSwizzle::one:");
5318 statement(" return vec<T, 4>(1, 1, 1, 1);");
5319 end_scope();
5320 end_scope();
5321 statement("return t.gather_compare(s, spvForward<Ts>(params)...);");
5322 end_scope();
5323 statement("");
5324 break;
5325
5326 case SPVFuncImplSubgroupBroadcast:
5327 // Metal doesn't allow broadcasting boolean values directly, but we can work around that by broadcasting
5328 // them as integers.
5329 statement("template<typename T>");
5330 statement("inline T spvSubgroupBroadcast(T value, ushort lane)");
5331 begin_scope();
5332 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5333 statement("return quad_broadcast(value, lane);");
5334 else
5335 statement("return simd_broadcast(value, lane);");
5336 end_scope();
5337 statement("");
5338 statement("template<>");
5339 statement("inline bool spvSubgroupBroadcast(bool value, ushort lane)");
5340 begin_scope();
5341 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5342 statement("return !!quad_broadcast((ushort)value, lane);");
5343 else
5344 statement("return !!simd_broadcast((ushort)value, lane);");
5345 end_scope();
5346 statement("");
5347 statement("template<uint N>");
5348 statement("inline vec<bool, N> spvSubgroupBroadcast(vec<bool, N> value, ushort lane)");
5349 begin_scope();
5350 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5351 statement("return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
5352 else
5353 statement("return (vec<bool, N>)simd_broadcast((vec<ushort, N>)value, lane);");
5354 end_scope();
5355 statement("");
5356 break;
5357
5358 case SPVFuncImplSubgroupBroadcastFirst:
5359 statement("template<typename T>");
5360 statement("inline T spvSubgroupBroadcastFirst(T value)");
5361 begin_scope();
5362 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5363 statement("return quad_broadcast_first(value);");
5364 else
5365 statement("return simd_broadcast_first(value);");
5366 end_scope();
5367 statement("");
5368 statement("template<>");
5369 statement("inline bool spvSubgroupBroadcastFirst(bool value)");
5370 begin_scope();
5371 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5372 statement("return !!quad_broadcast_first((ushort)value);");
5373 else
5374 statement("return !!simd_broadcast_first((ushort)value);");
5375 end_scope();
5376 statement("");
5377 statement("template<uint N>");
5378 statement("inline vec<bool, N> spvSubgroupBroadcastFirst(vec<bool, N> value)");
5379 begin_scope();
5380 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5381 statement("return (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value);");
5382 else
5383 statement("return (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value);");
5384 end_scope();
5385 statement("");
5386 break;
5387
5388 case SPVFuncImplSubgroupBallot:
5389 statement("inline uint4 spvSubgroupBallot(bool value)");
5390 begin_scope();
5391 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5392 {
5393 statement("return uint4((quad_vote::vote_t)quad_ballot(value), 0, 0, 0);");
5394 }
5395 else if (msl_options.is_ios())
5396 {
5397 // The current simd_vote on iOS uses a 32-bit integer-like object.
5398 statement("return uint4((simd_vote::vote_t)simd_ballot(value), 0, 0, 0);");
5399 }
5400 else
5401 {
5402 statement("simd_vote vote = simd_ballot(value);");
5403 statement("// simd_ballot() returns a 64-bit integer-like object, but");
5404 statement("// SPIR-V callers expect a uint4. We must convert.");
5405 statement("// FIXME: This won't include higher bits if Apple ever supports");
5406 statement("// 128 lanes in an SIMD-group.");
5407 statement("return uint4(as_type<uint2>((simd_vote::vote_t)vote), 0, 0);");
5408 }
5409 end_scope();
5410 statement("");
5411 break;
5412
5413 case SPVFuncImplSubgroupBallotBitExtract:
5414 statement("inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)");
5415 begin_scope();
5416 statement("return !!extract_bits(ballot[bit / 32], bit % 32, 1);");
5417 end_scope();
5418 statement("");
5419 break;
5420
5421 case SPVFuncImplSubgroupBallotFindLSB:
5422 statement("inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize)");
5423 begin_scope();
5424 if (msl_options.is_ios())
5425 {
5426 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
5427 }
5428 else
5429 {
5430 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
5431 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
5432 }
5433 statement("ballot &= mask;");
5434 statement("return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + "
5435 "ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);");
5436 end_scope();
5437 statement("");
5438 break;
5439
5440 case SPVFuncImplSubgroupBallotFindMSB:
5441 statement("inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize)");
5442 begin_scope();
5443 if (msl_options.is_ios())
5444 {
5445 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
5446 }
5447 else
5448 {
5449 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
5450 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
5451 }
5452 statement("ballot &= mask;");
5453 statement("return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - "
5454 "(clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), "
5455 "ballot.z == 0), ballot.w == 0);");
5456 end_scope();
5457 statement("");
5458 break;
5459
5460 case SPVFuncImplSubgroupBallotBitCount:
5461 statement("inline uint spvPopCount4(uint4 ballot)");
5462 begin_scope();
5463 statement("return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);");
5464 end_scope();
5465 statement("");
5466 statement("inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize)");
5467 begin_scope();
5468 if (msl_options.is_ios())
5469 {
5470 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
5471 }
5472 else
5473 {
5474 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
5475 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
5476 }
5477 statement("return spvPopCount4(ballot & mask);");
5478 end_scope();
5479 statement("");
5480 statement("inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
5481 begin_scope();
5482 if (msl_options.is_ios())
5483 {
5484 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID + 1), uint3(0));");
5485 }
5486 else
5487 {
5488 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), "
5489 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), "
5490 "uint2(0));");
5491 }
5492 statement("return spvPopCount4(ballot & mask);");
5493 end_scope();
5494 statement("");
5495 statement("inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
5496 begin_scope();
5497 if (msl_options.is_ios())
5498 {
5499 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID), uint2(0));");
5500 }
5501 else
5502 {
5503 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), "
5504 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));");
5505 }
5506 statement("return spvPopCount4(ballot & mask);");
5507 end_scope();
5508 statement("");
5509 break;
5510
5511 case SPVFuncImplSubgroupAllEqual:
5512 		// Metal doesn't provide a function to evaluate this directly, but we can
5513 		// implement it by comparing every thread's value to one thread's value
5514 // (in this case, the value of the first active thread). Then, by the transitive
5515 // property of equality, if all comparisons return true, then they are all equal.
5516 statement("template<typename T>");
5517 statement("inline bool spvSubgroupAllEqual(T value)");
5518 begin_scope();
5519 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5520 statement("return quad_all(all(value == quad_broadcast_first(value)));");
5521 else
5522 statement("return simd_all(all(value == simd_broadcast_first(value)));");
5523 end_scope();
5524 statement("");
5525 statement("template<>");
5526 statement("inline bool spvSubgroupAllEqual(bool value)");
5527 begin_scope();
5528 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5529 statement("return quad_all(value) || !quad_any(value);");
5530 else
5531 statement("return simd_all(value) || !simd_any(value);");
5532 end_scope();
5533 statement("");
5534 statement("template<uint N>");
5535 statement("inline bool spvSubgroupAllEqual(vec<bool, N> value)");
5536 begin_scope();
5537 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5538 statement("return quad_all(all(value == (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value)));");
5539 else
5540 statement("return simd_all(all(value == (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value)));");
5541 end_scope();
5542 statement("");
5543 break;
5544
5545 case SPVFuncImplSubgroupShuffle:
5546 statement("template<typename T>");
5547 statement("inline T spvSubgroupShuffle(T value, ushort lane)");
5548 begin_scope();
5549 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5550 statement("return quad_shuffle(value, lane);");
5551 else
5552 statement("return simd_shuffle(value, lane);");
5553 end_scope();
5554 statement("");
5555 statement("template<>");
5556 statement("inline bool spvSubgroupShuffle(bool value, ushort lane)");
5557 begin_scope();
5558 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5559 statement("return !!quad_shuffle((ushort)value, lane);");
5560 else
5561 statement("return !!simd_shuffle((ushort)value, lane);");
5562 end_scope();
5563 statement("");
5564 statement("template<uint N>");
5565 statement("inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)");
5566 begin_scope();
5567 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5568 statement("return (vec<bool, N>)quad_shuffle((vec<ushort, N>)value, lane);");
5569 else
5570 statement("return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);");
5571 end_scope();
5572 statement("");
5573 break;
5574
5575 case SPVFuncImplSubgroupShuffleXor:
5576 statement("template<typename T>");
5577 statement("inline T spvSubgroupShuffleXor(T value, ushort mask)");
5578 begin_scope();
5579 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5580 statement("return quad_shuffle_xor(value, mask);");
5581 else
5582 statement("return simd_shuffle_xor(value, mask);");
5583 end_scope();
5584 statement("");
5585 statement("template<>");
5586 statement("inline bool spvSubgroupShuffleXor(bool value, ushort mask)");
5587 begin_scope();
5588 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5589 statement("return !!quad_shuffle_xor((ushort)value, mask);");
5590 else
5591 statement("return !!simd_shuffle_xor((ushort)value, mask);");
5592 end_scope();
5593 statement("");
5594 statement("template<uint N>");
5595 statement("inline vec<bool, N> spvSubgroupShuffleXor(vec<bool, N> value, ushort mask)");
5596 begin_scope();
5597 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5598 statement("return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, mask);");
5599 else
5600 statement("return (vec<bool, N>)simd_shuffle_xor((vec<ushort, N>)value, mask);");
5601 end_scope();
5602 statement("");
5603 break;
5604
5605 case SPVFuncImplSubgroupShuffleUp:
5606 statement("template<typename T>");
5607 statement("inline T spvSubgroupShuffleUp(T value, ushort delta)");
5608 begin_scope();
5609 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5610 statement("return quad_shuffle_up(value, delta);");
5611 else
5612 statement("return simd_shuffle_up(value, delta);");
5613 end_scope();
5614 statement("");
5615 statement("template<>");
5616 statement("inline bool spvSubgroupShuffleUp(bool value, ushort delta)");
5617 begin_scope();
5618 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5619 statement("return !!quad_shuffle_up((ushort)value, delta);");
5620 else
5621 statement("return !!simd_shuffle_up((ushort)value, delta);");
5622 end_scope();
5623 statement("");
5624 statement("template<uint N>");
5625 statement("inline vec<bool, N> spvSubgroupShuffleUp(vec<bool, N> value, ushort delta)");
5626 begin_scope();
5627 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5628 statement("return (vec<bool, N>)quad_shuffle_up((vec<ushort, N>)value, delta);");
5629 else
5630 statement("return (vec<bool, N>)simd_shuffle_up((vec<ushort, N>)value, delta);");
5631 end_scope();
5632 statement("");
5633 break;
5634
5635 case SPVFuncImplSubgroupShuffleDown:
5636 statement("template<typename T>");
5637 statement("inline T spvSubgroupShuffleDown(T value, ushort delta)");
5638 begin_scope();
5639 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5640 statement("return quad_shuffle_down(value, delta);");
5641 else
5642 statement("return simd_shuffle_down(value, delta);");
5643 end_scope();
5644 statement("");
5645 statement("template<>");
5646 statement("inline bool spvSubgroupShuffleDown(bool value, ushort delta)");
5647 begin_scope();
5648 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5649 statement("return !!quad_shuffle_down((ushort)value, delta);");
5650 else
5651 statement("return !!simd_shuffle_down((ushort)value, delta);");
5652 end_scope();
5653 statement("");
5654 statement("template<uint N>");
5655 statement("inline vec<bool, N> spvSubgroupShuffleDown(vec<bool, N> value, ushort delta)");
5656 begin_scope();
5657 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5658 statement("return (vec<bool, N>)quad_shuffle_down((vec<ushort, N>)value, delta);");
5659 else
5660 statement("return (vec<bool, N>)simd_shuffle_down((vec<ushort, N>)value, delta);");
5661 end_scope();
5662 statement("");
5663 break;
5664
5665 case SPVFuncImplQuadBroadcast:
5666 statement("template<typename T>");
5667 statement("inline T spvQuadBroadcast(T value, uint lane)");
5668 begin_scope();
5669 statement("return quad_broadcast(value, lane);");
5670 end_scope();
5671 statement("");
5672 statement("template<>");
5673 statement("inline bool spvQuadBroadcast(bool value, uint lane)");
5674 begin_scope();
5675 statement("return !!quad_broadcast((ushort)value, lane);");
5676 end_scope();
5677 statement("");
5678 statement("template<uint N>");
5679 statement("inline vec<bool, N> spvQuadBroadcast(vec<bool, N> value, uint lane)");
5680 begin_scope();
5681 statement("return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
5682 end_scope();
5683 statement("");
5684 break;
5685
5686 case SPVFuncImplQuadSwap:
5687 // We can implement this easily based on the following table giving
5688 // the target lane ID from the direction and current lane ID:
5689 // Direction
5690 // | 0 | 1 | 2 |
5691 // ---+---+---+---+
5692 // L 0 | 1 2 3
5693 // a 1 | 0 3 2
5694 // n 2 | 3 0 1
5695 // e 3 | 2 1 0
5696 // Notice that target = source ^ (direction + 1).
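		// Worked example: a vertical swap (direction 1) from lane 2 gives 2 ^ (1 + 1) = 0.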
5697 statement("template<typename T>");
5698 statement("inline T spvQuadSwap(T value, uint dir)");
5699 begin_scope();
5700 statement("return quad_shuffle_xor(value, dir + 1);");
5701 end_scope();
5702 statement("");
5703 statement("template<>");
5704 statement("inline bool spvQuadSwap(bool value, uint dir)");
5705 begin_scope();
5706 statement("return !!quad_shuffle_xor((ushort)value, dir + 1);");
5707 end_scope();
5708 statement("");
5709 statement("template<uint N>");
5710 statement("inline vec<bool, N> spvQuadSwap(vec<bool, N> value, uint dir)");
5711 begin_scope();
5712 statement("return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, dir + 1);");
5713 end_scope();
5714 statement("");
5715 break;
5716
5717 case SPVFuncImplReflectScalar:
5718 // Metal does not support scalar versions of these functions.
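		// For scalars, reflect(I, N) = I - 2 * dot(N, I) * N reduces to i - 2 * i * n * n.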
5719 statement("template<typename T>");
5720 statement("inline T spvReflect(T i, T n)");
5721 begin_scope();
5722 statement("return i - T(2) * i * n * n;");
5723 end_scope();
5724 statement("");
5725 break;
5726
5727 case SPVFuncImplRefractScalar:
5728 // Metal does not support scalar versions of these functions.
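		// Scalar refract(I, N, eta): k = 1 - eta^2 * (1 - dot(N, I)^2).
		// k < 0 signals total internal reflection, which returns 0; otherwise the
		// result is eta * I - (eta * dot(N, I) + sqrt(k)) * N.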
5729 statement("template<typename T>");
5730 statement("inline T spvRefract(T i, T n, T eta)");
5731 begin_scope();
5732 statement("T NoI = n * i;");
5733 statement("T NoI2 = NoI * NoI;");
5734 statement("T k = T(1) - eta * eta * (T(1) - NoI2);");
5735 statement("if (k < T(0))");
5736 begin_scope();
5737 statement("return T(0);");
5738 end_scope();
5739 statement("else");
5740 begin_scope();
5741 statement("return eta * i - (eta * NoI + sqrt(k)) * n;");
5742 end_scope();
5743 end_scope();
5744 statement("");
5745 break;
5746
5747 case SPVFuncImplFaceForwardScalar:
5748 // Metal does not support scalar versions of these functions.
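		// Scalar faceforward(N, I, Nref): N if dot(Nref, I) < 0, otherwise -N.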
5749 statement("template<typename T>");
5750 statement("inline T spvFaceForward(T n, T i, T nref)");
5751 begin_scope();
5752 statement("return i * nref < T(0) ? n : -n;");
5753 end_scope();
5754 statement("");
5755 break;
5756
5757 case SPVFuncImplChromaReconstructNearest2Plane:
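		// Two-plane layout: plane0 carries luma (Y) in .r, plane1 carries
		// interleaved chroma (CbCr) in .rg. The result is packed with G = Y,
		// B = Cb, R = Cr, the component order the spvConvertYCbCr* helpers expect.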
5758 statement("template<typename T, typename... LodOptions>");
5759 statement("inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, sampler "
5760 "samp, float2 coord, LodOptions... options)");
5761 begin_scope();
5762 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5763 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5764 statement("ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
5765 statement("return ycbcr;");
5766 end_scope();
5767 statement("");
5768 break;
5769
5770 case SPVFuncImplChromaReconstructNearest3Plane:
5771 statement("template<typename T, typename... LodOptions>");
5772 statement("inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, "
5773 "texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5774 begin_scope();
5775 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5776 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5777 statement("ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5778 statement("ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5779 statement("return ycbcr;");
5780 end_scope();
5781 statement("");
5782 break;
5783
5784 case SPVFuncImplChromaReconstructLinear422CositedEven2Plane:
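		// Cosited-even siting: chroma samples are co-located with even luma
		// samples. Odd luma columns (fractional chroma coordinate) average the
		// two nearest chroma samples; even columns sample directly.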
5785 statement("template<typename T, typename... LodOptions>");
5786 statement("inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
5787 "plane1, sampler samp, float2 coord, LodOptions... options)");
5788 begin_scope();
5789 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5790 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5791 statement("if (fract(coord.x * plane1.get_width()) != 0.0)");
5792 begin_scope();
5793 statement("ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5794 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).rg);");
5795 end_scope();
5796 statement("else");
5797 begin_scope();
5798 statement("ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
5799 end_scope();
5800 statement("return ycbcr;");
5801 end_scope();
5802 statement("");
5803 break;
5804
5805 case SPVFuncImplChromaReconstructLinear422CositedEven3Plane:
5806 statement("template<typename T, typename... LodOptions>");
5807 statement("inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
5808 "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5809 begin_scope();
5810 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5811 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5812 statement("if (fract(coord.x * plane1.get_width()) != 0.0)");
5813 begin_scope();
5814 statement("ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5815 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
5816 statement("ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5817 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
5818 end_scope();
5819 statement("else");
5820 begin_scope();
5821 statement("ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5822 statement("ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5823 end_scope();
5824 statement("return ycbcr;");
5825 end_scope();
5826 statement("");
5827 break;
5828
5829 case SPVFuncImplChromaReconstructLinear422Midpoint2Plane:
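		// Midpoint siting: chroma samples sit halfway between luma pairs, so we
		// blend 75/25 with a neighboring chroma texel, picking the neighbor
		// (+1 or -1) from the parity of the luma column.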
5830 statement("template<typename T, typename... LodOptions>");
5831 statement("inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
5832 "plane1, sampler samp, float2 coord, LodOptions... options)");
5833 begin_scope();
5834 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5835 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5836 statement("int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
5837 statement("ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5838 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).rg);");
5839 statement("return ycbcr;");
5840 end_scope();
5841 statement("");
5842 break;
5843
5844 case SPVFuncImplChromaReconstructLinear422Midpoint3Plane:
5845 statement("template<typename T, typename... LodOptions>");
5846 statement("inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
5847 "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5848 begin_scope();
5849 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5850 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5851 statement("int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
5852 statement("ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5853 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
5854 statement("ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5855 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
5856 statement("return ycbcr;");
5857 end_scope();
5858 statement("");
5859 break;
5860
5861 case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane:
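		// The 4:2:0 variants below reconstruct chroma bilinearly from up to four
		// chroma texels. 'ab' holds per-axis blend weights derived from the parity
		// of the nearest luma texel: 0 or 0.5 on a cosited-even axis, 0.25 or 0.75
		// on a midpoint axis (where the rounded coordinate is first shifted by
		// half a texel).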
5862 statement("template<typename T, typename... LodOptions>");
5863 statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
5864 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
5865 begin_scope();
5866 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5867 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5868 statement("float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
5869 statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5870 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5871 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5872 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
5873 statement("return ycbcr;");
5874 end_scope();
5875 statement("");
5876 break;
5877
5878 case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane:
5879 statement("template<typename T, typename... LodOptions>");
5880 statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
5881 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5882 begin_scope();
5883 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5884 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5885 statement("float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
5886 statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5887 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5888 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5889 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5890 statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5891 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5892 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5893 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5894 statement("return ycbcr;");
5895 end_scope();
5896 statement("");
5897 break;
5898
5899 case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane:
5900 statement("template<typename T, typename... LodOptions>");
5901 statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
5902 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
5903 begin_scope();
5904 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5905 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5906 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
5907 "0)) * 0.5);");
5908 statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5909 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5910 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5911 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
5912 statement("return ycbcr;");
5913 end_scope();
5914 statement("");
5915 break;
5916
5917 case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane:
5918 statement("template<typename T, typename... LodOptions>");
5919 statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
5920 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5921 begin_scope();
5922 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5923 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5924 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
5925 "0)) * 0.5);");
5926 statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5927 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5928 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5929 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5930 statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5931 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5932 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5933 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5934 statement("return ycbcr;");
5935 end_scope();
5936 statement("");
5937 break;
5938
5939 case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane:
5940 statement("template<typename T, typename... LodOptions>");
5941 statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
5942 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
5943 begin_scope();
5944 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5945 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5946 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
5947 "0.5)) * 0.5);");
5948 statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5949 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5950 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5951 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
5952 statement("return ycbcr;");
5953 end_scope();
5954 statement("");
5955 break;
5956
5957 case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane:
5958 statement("template<typename T, typename... LodOptions>");
5959 statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
5960 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
5961 begin_scope();
5962 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5963 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5964 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
5965 "0.5)) * 0.5);");
5966 statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5967 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5968 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5969 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5970 statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
5971 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5972 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5973 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
5974 statement("return ycbcr;");
5975 end_scope();
5976 statement("");
5977 break;
5978
5979 case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane:
5980 statement("template<typename T, typename... LodOptions>");
5981 statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
5982 "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
5983 begin_scope();
5984 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
5985 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
5986 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
5987 "0.5)) * 0.5);");
5988 statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
5989 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
5990 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
5991 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
5992 statement("return ycbcr;");
5993 end_scope();
5994 statement("");
5995 break;
5996
5997 case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane:
5998 statement("template<typename T, typename... LodOptions>");
5999 statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
6000 "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
6001 begin_scope();
6002 statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
6003 statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
6004 statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
6005 "0.5)) * 0.5);");
6006 statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
6007 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
6008 "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
6009 "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
6010 statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
6011 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
6012 "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
6013 "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
6014 statement("return ycbcr;");
6015 end_scope();
6016 statement("");
6017 break;
6018
6019 case SPVFuncImplExpandITUFullRange:
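		// ITU full range: n-bit chroma is stored with a bias of 2^(n-1), which a
		// normalized texture read scales to 2^(n-1) / (2^n - 1); subtracting that
		// re-centers Cb/Cr around zero.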
6020 statement("template<typename T>");
6021 statement("inline vec<T, 4> spvExpandITUFullRange(vec<T, 4> ycbcr, int n)");
6022 begin_scope();
6023 statement("ycbcr.br -= exp2(T(n-1))/(exp2(T(n))-1);");
6024 statement("return ycbcr;");
6025 end_scope();
6026 statement("");
6027 break;
6028
6029 case SPVFuncImplExpandITUNarrowRange:
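		// ITU narrow range generalizes the 8-bit encoding (luma in [16, 235],
		// chroma in [16, 240]) to n bits:
		// Y' = (y * (2^n - 1) - 16 * 2^(n-8)) / (219 * 2^(n-8)), and likewise for
		// chroma with bias 128 and scale 224.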
6030 statement("template<typename T>");
6031 statement("inline vec<T, 4> spvExpandITUNarrowRange(vec<T, 4> ycbcr, int n)");
6032 begin_scope();
6033 statement("ycbcr.g = (ycbcr.g * (exp2(T(n)) - 1) - ldexp(T(16), n - 8))/ldexp(T(219), n - 8);");
6034 statement("ycbcr.br = (ycbcr.br * (exp2(T(n)) - 1) - ldexp(T(128), n - 8))/ldexp(T(224), n - 8);");
6035 statement("return ycbcr;");
6036 end_scope();
6037 statement("");
6038 break;
6039
6040 case SPVFuncImplConvertYCbCrBT709:
6041 statement("// cf. Khronos Data Format Specification, section 15.1.1");
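		// The column-major matrix below is applied to (Y, Cb, Cr). With
		// Kr = 0.2126, Kb = 0.0722, Kg = 0.7152: R = Y + 2(1-Kr)Cr = Y + 1.5748 Cr,
		// B = Y + 2(1-Kb)Cb = Y + 1.8556 Cb, G = Y - (2(1-Kb)Kb/Kg)Cb - (2(1-Kr)Kr/Kg)Cr.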
6042 statement("constant float3x3 spvBT709Factors = {{1, 1, 1}, {0, -0.13397432/0.7152, 1.8556}, {1.5748, "
6043 "-0.33480248/0.7152, 0}};");
6044 statement("");
6045 statement("template<typename T>");
6046 statement("inline vec<T, 4> spvConvertYCbCrBT709(vec<T, 4> ycbcr)");
6047 begin_scope();
6048 statement("vec<T, 4> rgba;");
6049 statement("rgba.rgb = vec<T, 3>(spvBT709Factors * ycbcr.gbr);");
6050 statement("rgba.a = ycbcr.a;");
6051 statement("return rgba;");
6052 end_scope();
6053 statement("");
6054 break;
6055
6056 case SPVFuncImplConvertYCbCrBT601:
6057 statement("// cf. Khronos Data Format Specification, section 15.1.2");
6058 statement("constant float3x3 spvBT601Factors = {{1, 1, 1}, {0, -0.202008/0.587, 1.772}, {1.402, "
6059 "-0.419198/0.587, 0}};");
6060 statement("");
6061 statement("template<typename T>");
6062 statement("inline vec<T, 4> spvConvertYCbCrBT601(vec<T, 4> ycbcr)");
6063 begin_scope();
6064 statement("vec<T, 4> rgba;");
6065 statement("rgba.rgb = vec<T, 3>(spvBT601Factors * ycbcr.gbr);");
6066 statement("rgba.a = ycbcr.a;");
6067 statement("return rgba;");
6068 end_scope();
6069 statement("");
6070 break;
6071
6072 case SPVFuncImplConvertYCbCrBT2020:
6073 statement("// cf. Khronos Data Format Specification, section 15.1.3");
6074 statement("constant float3x3 spvBT2020Factors = {{1, 1, 1}, {0, -0.11156702/0.6780, 1.8814}, {1.4746, "
6075 "-0.38737742/0.6780, 0}};");
6076 statement("");
6077 statement("template<typename T>");
6078 statement("inline vec<T, 4> spvConvertYCbCrBT2020(vec<T, 4> ycbcr)");
6079 begin_scope();
6080 statement("vec<T, 4> rgba;");
6081 statement("rgba.rgb = vec<T, 3>(spvBT2020Factors * ycbcr.gbr);");
6082 statement("rgba.a = ycbcr.a;");
6083 statement("return rgba;");
6084 end_scope();
6085 statement("");
6086 break;
6087
6088 case SPVFuncImplDynamicImageSampler:
6089 statement("enum class spvFormatResolution");
6090 begin_scope();
6091 statement("_444 = 0,");
6092 statement("_422,");
6093 statement("_420");
6094 end_scope_decl();
6095 statement("");
6096 statement("enum class spvChromaFilter");
6097 begin_scope();
6098 statement("nearest = 0,");
6099 statement("linear");
6100 end_scope_decl();
6101 statement("");
6102 statement("enum class spvXChromaLocation");
6103 begin_scope();
6104 statement("cosited_even = 0,");
6105 statement("midpoint");
6106 end_scope_decl();
6107 statement("");
6108 statement("enum class spvYChromaLocation");
6109 begin_scope();
6110 statement("cosited_even = 0,");
6111 statement("midpoint");
6112 end_scope_decl();
6113 statement("");
6114 statement("enum class spvYCbCrModelConversion");
6115 begin_scope();
6116 statement("rgb_identity = 0,");
6117 statement("ycbcr_identity,");
6118 statement("ycbcr_bt_709,");
6119 statement("ycbcr_bt_601,");
6120 statement("ycbcr_bt_2020");
6121 end_scope_decl();
6122 statement("");
6123 statement("enum class spvYCbCrRange");
6124 begin_scope();
6125 statement("itu_full = 0,");
6126 statement("itu_narrow");
6127 end_scope_decl();
6128 statement("");
6129 statement("struct spvComponentBits");
6130 begin_scope();
6131 statement("constexpr explicit spvComponentBits(int v) thread : value(v) {}");
6132 statement("uchar value : 6;");
6133 end_scope_decl();
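		// The 6-bit field covers any bits-per-component value (0-63), and the
		// wrapper struct gives the variadic build() overloads emitted below a
		// distinct type to dispatch on.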
6134 statement("// A class corresponding to metal::sampler which holds sampler");
6135 statement("// Y'CbCr conversion info.");
6136 statement("struct spvYCbCrSampler");
6137 begin_scope();
6138 statement("constexpr spvYCbCrSampler() thread : val(build()) {}");
6139 statement("template<typename... Ts>");
6140 statement("constexpr spvYCbCrSampler(Ts... t) thread : val(build(t...)) {}");
6141 statement("constexpr spvYCbCrSampler(const thread spvYCbCrSampler& s) thread = default;");
6142 statement("");
6143 statement("spvFormatResolution get_resolution() const thread");
6144 begin_scope();
6145 statement("return spvFormatResolution((val & resolution_mask) >> resolution_base);");
6146 end_scope();
6147 statement("spvChromaFilter get_chroma_filter() const thread");
6148 begin_scope();
6149 statement("return spvChromaFilter((val & chroma_filter_mask) >> chroma_filter_base);");
6150 end_scope();
6151 statement("spvXChromaLocation get_x_chroma_offset() const thread");
6152 begin_scope();
6153 statement("return spvXChromaLocation((val & x_chroma_off_mask) >> x_chroma_off_base);");
6154 end_scope();
6155 statement("spvYChromaLocation get_y_chroma_offset() const thread");
6156 begin_scope();
6157 statement("return spvYChromaLocation((val & y_chroma_off_mask) >> y_chroma_off_base);");
6158 end_scope();
6159 statement("spvYCbCrModelConversion get_ycbcr_model() const thread");
6160 begin_scope();
6161 statement("return spvYCbCrModelConversion((val & ycbcr_model_mask) >> ycbcr_model_base);");
6162 end_scope();
6163 statement("spvYCbCrRange get_ycbcr_range() const thread");
6164 begin_scope();
6165 statement("return spvYCbCrRange((val & ycbcr_range_mask) >> ycbcr_range_base);");
6166 end_scope();
6167 statement("int get_bpc() const thread { return (val & bpc_mask) >> bpc_base; }");
6168 statement("");
6169 statement("private:");
6170 statement("ushort val;");
6171 statement("");
6172 statement("constexpr static constant ushort resolution_bits = 2;");
6173 statement("constexpr static constant ushort chroma_filter_bits = 2;");
6174 statement("constexpr static constant ushort x_chroma_off_bit = 1;");
6175 statement("constexpr static constant ushort y_chroma_off_bit = 1;");
6176 statement("constexpr static constant ushort ycbcr_model_bits = 3;");
6177 statement("constexpr static constant ushort ycbcr_range_bit = 1;");
6178 statement("constexpr static constant ushort bpc_bits = 6;");
6179 statement("");
6180 statement("constexpr static constant ushort resolution_base = 0;");
6181 statement("constexpr static constant ushort chroma_filter_base = 2;");
6182 statement("constexpr static constant ushort x_chroma_off_base = 4;");
6183 statement("constexpr static constant ushort y_chroma_off_base = 5;");
6184 statement("constexpr static constant ushort ycbcr_model_base = 6;");
6185 statement("constexpr static constant ushort ycbcr_range_base = 9;");
6186 statement("constexpr static constant ushort bpc_base = 10;");
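		// Resulting layout of 'val': bits [1:0] resolution, [3:2] chroma filter,
		// [4] x chroma offset, [5] y chroma offset, [8:6] ycbcr model,
		// [9] ycbcr range, [15:10] bits per component.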
6187 statement("");
6188 statement(
6189 "constexpr static constant ushort resolution_mask = ((1 << resolution_bits) - 1) << resolution_base;");
6190 statement("constexpr static constant ushort chroma_filter_mask = ((1 << chroma_filter_bits) - 1) << "
6191 "chroma_filter_base;");
6192 statement("constexpr static constant ushort x_chroma_off_mask = ((1 << x_chroma_off_bit) - 1) << "
6193 "x_chroma_off_base;");
6194 statement("constexpr static constant ushort y_chroma_off_mask = ((1 << y_chroma_off_bit) - 1) << "
6195 "y_chroma_off_base;");
6196 statement("constexpr static constant ushort ycbcr_model_mask = ((1 << ycbcr_model_bits) - 1) << "
6197 "ycbcr_model_base;");
6198 statement("constexpr static constant ushort ycbcr_range_mask = ((1 << ycbcr_range_bit) - 1) << "
6199 "ycbcr_range_base;");
6200 statement("constexpr static constant ushort bpc_mask = ((1 << bpc_bits) - 1) << bpc_base;");
6201 statement("");
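		// Each build() overload peels off one argument, ORs its setting into the
		// matching bit field, and recurses on the rest, so the constructor accepts
		// the settings in any order.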
6202 statement("static constexpr ushort build()");
6203 begin_scope();
6204 statement("return 0;");
6205 end_scope();
6206 statement("");
6207 statement("template<typename... Ts>");
6208 statement("static constexpr ushort build(spvFormatResolution res, Ts... t)");
6209 begin_scope();
6210 statement("return (ushort(res) << resolution_base) | (build(t...) & ~resolution_mask);");
6211 end_scope();
6212 statement("");
6213 statement("template<typename... Ts>");
6214 statement("static constexpr ushort build(spvChromaFilter filt, Ts... t)");
6215 begin_scope();
6216 statement("return (ushort(filt) << chroma_filter_base) | (build(t...) & ~chroma_filter_mask);");
6217 end_scope();
6218 statement("");
6219 statement("template<typename... Ts>");
6220 statement("static constexpr ushort build(spvXChromaLocation loc, Ts... t)");
6221 begin_scope();
6222 statement("return (ushort(loc) << x_chroma_off_base) | (build(t...) & ~x_chroma_off_mask);");
6223 end_scope();
6224 statement("");
6225 statement("template<typename... Ts>");
6226 statement("static constexpr ushort build(spvYChromaLocation loc, Ts... t)");
6227 begin_scope();
6228 statement("return (ushort(loc) << y_chroma_off_base) | (build(t...) & ~y_chroma_off_mask);");
6229 end_scope();
6230 statement("");
6231 statement("template<typename... Ts>");
6232 statement("static constexpr ushort build(spvYCbCrModelConversion model, Ts... t)");
6233 begin_scope();
6234 statement("return (ushort(model) << ycbcr_model_base) | (build(t...) & ~ycbcr_model_mask);");
6235 end_scope();
6236 statement("");
6237 statement("template<typename... Ts>");
6238 statement("static constexpr ushort build(spvYCbCrRange range, Ts... t)");
6239 begin_scope();
6240 statement("return (ushort(range) << ycbcr_range_base) | (build(t...) & ~ycbcr_range_mask);");
6241 end_scope();
6242 statement("");
6243 statement("template<typename... Ts>");
6244 statement("static constexpr ushort build(spvComponentBits bpc, Ts... t)");
6245 begin_scope();
6246 statement("return (ushort(bpc.value) << bpc_base) | (build(t...) & ~bpc_mask);");
6247 end_scope();
6248 end_scope_decl();
6249 statement("");
6250 statement("// A class which can hold up to three textures and a sampler, including");
6251 statement("// Y'CbCr conversion info, used to pass combined image-samplers");
6252 statement("// dynamically to functions.");
6253 statement("template<typename T>");
6254 statement("struct spvDynamicImageSampler");
6255 begin_scope();
6256 statement("texture2d<T> plane0;");
6257 statement("texture2d<T> plane1;");
6258 statement("texture2d<T> plane2;");
6259 statement("sampler samp;");
6260 statement("spvYCbCrSampler ycbcr_samp;");
6261 statement("uint swizzle = 0;");
6262 statement("");
6263 if (msl_options.swizzle_texture_samples)
6264 {
6265 statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, uint sw) thread :");
6266 statement(" plane0(tex), samp(samp), swizzle(sw) {}");
6267 }
6268 else
6269 {
6270 statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp) thread :");
6271 statement(" plane0(tex), samp(samp) {}");
6272 }
6273 statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, spvYCbCrSampler ycbcr_samp, "
6274 "uint sw) thread :");
6275 statement(" plane0(tex), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
6276 statement("constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1,");
6277 statement(" sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
6278 statement(" plane0(plane0), plane1(plane1), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
6279 statement(
6280 "constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1, texture2d<T> plane2,");
6281 statement(" sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
6282 statement(" plane0(plane0), plane1(plane1), plane2(plane2), samp(samp), ycbcr_samp(ycbcr_samp), "
6283 "swizzle(sw) {}");
6284 statement("");
6285 // XXX This is really hard to follow... I've left comments to make it a bit easier.
6286 statement("template<typename... LodOptions>");
6287 statement("vec<T, 4> do_sample(float2 coord, LodOptions... options) const thread");
6288 begin_scope();
6289 statement("if (!is_null_texture(plane1))");
6290 begin_scope();
6291 statement("if (ycbcr_samp.get_resolution() == spvFormatResolution::_444 ||");
6292 statement(" ycbcr_samp.get_chroma_filter() == spvChromaFilter::nearest)");
6293 begin_scope();
6294 statement("if (!is_null_texture(plane2))");
6295 statement(" return spvChromaReconstructNearest(plane0, plane1, plane2, samp, coord,");
6296 statement(" spvForward<LodOptions>(options)...);");
6297 statement(
6298 "return spvChromaReconstructNearest(plane0, plane1, samp, coord, spvForward<LodOptions>(options)...);");
6299 end_scope(); // if (resolution == 444 || chroma_filter == nearest)
6300 statement("switch (ycbcr_samp.get_resolution())");
6301 begin_scope();
6302 statement("case spvFormatResolution::_444: break;");
6303 statement("case spvFormatResolution::_422:");
6304 begin_scope();
6305 statement("switch (ycbcr_samp.get_x_chroma_offset())");
6306 begin_scope();
6307 statement("case spvXChromaLocation::cosited_even:");
6308 statement(" if (!is_null_texture(plane2))");
6309 statement(" return spvChromaReconstructLinear422CositedEven(");
6310 statement(" plane0, plane1, plane2, samp,");
6311 statement(" coord, spvForward<LodOptions>(options)...);");
6312 statement(" return spvChromaReconstructLinear422CositedEven(");
6313 statement(" plane0, plane1, samp, coord,");
6314 statement(" spvForward<LodOptions>(options)...);");
6315 statement("case spvXChromaLocation::midpoint:");
6316 statement(" if (!is_null_texture(plane2))");
6317 statement(" return spvChromaReconstructLinear422Midpoint(");
6318 statement(" plane0, plane1, plane2, samp,");
6319 statement(" coord, spvForward<LodOptions>(options)...);");
6320 statement(" return spvChromaReconstructLinear422Midpoint(");
6321 statement(" plane0, plane1, samp, coord,");
6322 statement(" spvForward<LodOptions>(options)...);");
6323 end_scope(); // switch (x_chroma_offset)
6324 end_scope(); // case 422:
6325 statement("case spvFormatResolution::_420:");
6326 begin_scope();
6327 statement("switch (ycbcr_samp.get_x_chroma_offset())");
6328 begin_scope();
6329 statement("case spvXChromaLocation::cosited_even:");
6330 begin_scope();
6331 statement("switch (ycbcr_samp.get_y_chroma_offset())");
6332 begin_scope();
6333 statement("case spvYChromaLocation::cosited_even:");
6334 statement(" if (!is_null_texture(plane2))");
6335 statement(" return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
6336 statement(" plane0, plane1, plane2, samp,");
6337 statement(" coord, spvForward<LodOptions>(options)...);");
6338 statement(" return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
6339 statement(" plane0, plane1, samp, coord,");
6340 statement(" spvForward<LodOptions>(options)...);");
6341 statement("case spvYChromaLocation::midpoint:");
6342 statement(" if (!is_null_texture(plane2))");
6343 statement(" return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
6344 statement(" plane0, plane1, plane2, samp,");
6345 statement(" coord, spvForward<LodOptions>(options)...);");
6346 statement(" return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
6347 statement(" plane0, plane1, samp, coord,");
6348 statement(" spvForward<LodOptions>(options)...);");
6349 end_scope(); // switch (y_chroma_offset)
6350 end_scope(); // case x::cosited_even:
6351 statement("case spvXChromaLocation::midpoint:");
6352 begin_scope();
6353 statement("switch (ycbcr_samp.get_y_chroma_offset())");
6354 begin_scope();
6355 statement("case spvYChromaLocation::cosited_even:");
6356 statement(" if (!is_null_texture(plane2))");
6357 statement(" return spvChromaReconstructLinear420XMidpointYCositedEven(");
6358 statement(" plane0, plane1, plane2, samp,");
6359 statement(" coord, spvForward<LodOptions>(options)...);");
6360 statement(" return spvChromaReconstructLinear420XMidpointYCositedEven(");
6361 statement(" plane0, plane1, samp, coord,");
6362 statement(" spvForward<LodOptions>(options)...);");
6363 statement("case spvYChromaLocation::midpoint:");
6364 statement(" if (!is_null_texture(plane2))");
6365 statement(" return spvChromaReconstructLinear420XMidpointYMidpoint(");
6366 statement(" plane0, plane1, plane2, samp,");
6367 statement(" coord, spvForward<LodOptions>(options)...);");
6368 statement(" return spvChromaReconstructLinear420XMidpointYMidpoint(");
6369 statement(" plane0, plane1, samp, coord,");
6370 statement(" spvForward<LodOptions>(options)...);");
6371 end_scope(); // switch (y_chroma_offset)
6372 end_scope(); // case x::midpoint
6373 end_scope(); // switch (x_chroma_offset)
6374 end_scope(); // case 420:
6375 end_scope(); // switch (resolution)
6376 end_scope(); // if (multiplanar)
6377 statement("return plane0.sample(samp, coord, spvForward<LodOptions>(options)...);");
6378 end_scope(); // do_sample()
6379 statement("template <typename... LodOptions>");
6380 statement("vec<T, 4> sample(float2 coord, LodOptions... options) const thread");
6381 begin_scope();
6382 statement(
6383 "vec<T, 4> s = spvTextureSwizzle(do_sample(coord, spvForward<LodOptions>(options)...), swizzle);");
6384 statement("if (ycbcr_samp.get_ycbcr_model() == spvYCbCrModelConversion::rgb_identity)");
6385 statement(" return s;");
6386 statement("");
6387 statement("switch (ycbcr_samp.get_ycbcr_range())");
6388 begin_scope();
6389 statement("case spvYCbCrRange::itu_full:");
6390 statement(" s = spvExpandITUFullRange(s, ycbcr_samp.get_bpc());");
6391 statement(" break;");
6392 statement("case spvYCbCrRange::itu_narrow:");
6393 statement(" s = spvExpandITUNarrowRange(s, ycbcr_samp.get_bpc());");
6394 statement(" break;");
6395 end_scope();
6396 statement("");
6397 statement("switch (ycbcr_samp.get_ycbcr_model())");
6398 begin_scope();
6399 statement("case spvYCbCrModelConversion::rgb_identity:"); // Silence Clang warning
6400 statement("case spvYCbCrModelConversion::ycbcr_identity:");
6401 statement(" return s;");
6402 statement("case spvYCbCrModelConversion::ycbcr_bt_709:");
6403 statement(" return spvConvertYCbCrBT709(s);");
6404 statement("case spvYCbCrModelConversion::ycbcr_bt_601:");
6405 statement(" return spvConvertYCbCrBT601(s);");
6406 statement("case spvYCbCrModelConversion::ycbcr_bt_2020:");
6407 statement(" return spvConvertYCbCrBT2020(s);");
6408 end_scope();
6409 end_scope();
6410 statement("");
6411 // Sampler Y'CbCr conversion forbids offsets.
6412 statement("vec<T, 4> sample(float2 coord, int2 offset) const thread");
6413 begin_scope();
6414 if (msl_options.swizzle_texture_samples)
6415 statement("return spvTextureSwizzle(plane0.sample(samp, coord, offset), swizzle);");
6416 else
6417 statement("return plane0.sample(samp, coord, offset);");
6418 end_scope();
6419 statement("template<typename lod_options>");
6420 statement("vec<T, 4> sample(float2 coord, lod_options options, int2 offset) const thread");
6421 begin_scope();
6422 if (msl_options.swizzle_texture_samples)
6423 statement("return spvTextureSwizzle(plane0.sample(samp, coord, options, offset), swizzle);");
6424 else
6425 statement("return plane0.sample(samp, coord, options, offset);");
6426 end_scope();
6427 statement("#if __HAVE_MIN_LOD_CLAMP__");
6428 statement("vec<T, 4> sample(float2 coord, bias b, min_lod_clamp min_lod, int2 offset) const thread");
6429 begin_scope();
6430 statement("return plane0.sample(samp, coord, b, min_lod, offset);");
6431 end_scope();
6432 statement(
6433 "vec<T, 4> sample(float2 coord, gradient2d grad, min_lod_clamp min_lod, int2 offset) const thread");
6434 begin_scope();
6435 statement("return plane0.sample(samp, coord, grad, min_lod, offset);");
6436 end_scope();
6437 statement("#endif");
6438 statement("");
6439 // Y'CbCr conversion forbids all operations but sampling.
6440 statement("vec<T, 4> read(uint2 coord, uint lod = 0) const thread");
6441 begin_scope();
6442 statement("return plane0.read(coord, lod);");
6443 end_scope();
6444 statement("");
6445 statement("vec<T, 4> gather(float2 coord, int2 offset = int2(0), component c = component::x) const thread");
6446 begin_scope();
6447 if (msl_options.swizzle_texture_samples)
6448 statement("return spvGatherSwizzle(plane0, samp, swizzle, c, coord, offset);");
6449 else
6450 statement("return plane0.gather(samp, coord, offset, c);");
6451 end_scope();
6452 end_scope_decl();
6453 statement("");
6454
6455 default:
6456 break;
6457 }
6458 }
6459 }
6460
6461 static string inject_top_level_storage_qualifier(const string &expr, const string &qualifier)
6462 {
6463 // Easier to do this through text munging since the qualifier does not exist in the type system at all,
6464 // and plumbing in all that information is not very helpful.
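	// Illustrative examples: injecting "constant" into "float4 _42[3]" yields
	// "constant float4 _42[3]", while "device uint* _17" becomes
	// "device uint* constant _17"; the qualifier lands after the last '*' or '&'.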
6465 size_t last_reference = expr.find_last_of('&');
6466 size_t last_pointer = expr.find_last_of('*');
6467 size_t last_significant = string::npos;
6468
6469 if (last_reference == string::npos)
6470 last_significant = last_pointer;
6471 else if (last_pointer == string::npos)
6472 last_significant = last_reference;
6473 else
6474 last_significant = std::max(last_reference, last_pointer);
6475
6476 if (last_significant == string::npos)
6477 return join(qualifier, " ", expr);
6478 else
6479 {
6480 return join(expr.substr(0, last_significant + 1), " ",
6481 qualifier, expr.substr(last_significant + 1, string::npos));
6482 }
6483 }
6484
6485 // Undefined global memory is not allowed in MSL.
6486 // Declare such values as constants initialized to zeros. Use {}, as global constructors can break Metal.
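// For example (hypothetical ID), an OpUndef of type float4 is emitted as:
//   constant float4 _42 = {};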
6487 void CompilerMSL::declare_undefined_values()
6488 {
6489 bool emitted = false;
6490 ir.for_each_typed_id<SPIRUndef>([&](uint32_t, SPIRUndef &undef) {
6491 auto &type = this->get<SPIRType>(undef.basetype);
6492 // OpUndef can be void for some reason ...
6493 if (type.basetype == SPIRType::Void)
6494 return;
6495
6496 statement(inject_top_level_storage_qualifier(
6497 variable_decl(type, to_name(undef.self), undef.self),
6498 "constant"),
6499 " = {};");
6500 emitted = true;
6501 });
6502
6503 if (emitted)
6504 statement("");
6505 }
6506
6507 void CompilerMSL::declare_constant_arrays()
6508 {
6509 bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
6510
6511 // MSL cannot declare arrays inline (except when declaring a variable), so we must hoist them out to
6512 // global constants directly, which lets us use those constants in variable expressions.
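	// For example (hypothetical ID and values), a constant float[2] is hoisted to:
	//   constant float _16[2] = { 1.0, 2.0 };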
6513 bool emitted = false;
6514
6515 ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
6516 if (c.specialization)
6517 return;
6518
6519 auto &type = this->get<SPIRType>(c.constant_type);
6520 // Constant arrays of non-primitive types (e.g. matrices) won't link properly into Metal libraries.
6521 // FIXME: However, hoisting constants to main() means we need to pass down constant arrays to leaf functions if they are used there.
6522 // If there are multiple functions in the module, drop this case to avoid breaking use cases which do not need to
6523 // link into Metal libraries. This is hacky.
6524 if (!type.array.empty() && (!fully_inlined || is_scalar(type) || is_vector(type)))
6525 {
6526 auto name = to_name(c.self);
6527 statement(inject_top_level_storage_qualifier(variable_decl(type, name), "constant"),
6528 " = ", constant_expression(c), ";");
6529 emitted = true;
6530 }
6531 });
6532
6533 if (emitted)
6534 statement("");
6535 }
6536
6537 // Constant arrays of non-primitive types (e.g. matrices) won't link properly into Metal libraries
6538 void CompilerMSL::declare_complex_constant_arrays()
6539 {
6540 // If we do not have a fully inlined module, we did not opt in to
6541 // declaring constant arrays of complex types. See CompilerMSL::declare_constant_arrays().
6542 bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
6543 if (!fully_inlined)
6544 return;
6545
6546 // MSL cannot declare arrays inline (except when declaring a variable), so we must hoist them out to
6547 // global constants directly, which lets us use those constants in variable expressions.
6548 bool emitted = false;
6549
6550 ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
6551 if (c.specialization)
6552 return;
6553
6554 auto &type = this->get<SPIRType>(c.constant_type);
6555 if (!type.array.empty() && !(is_scalar(type) || is_vector(type)))
6556 {
6557 auto name = to_name(c.self);
6558 statement("", variable_decl(type, name), " = ", constant_expression(c), ";");
6559 emitted = true;
6560 }
6561 });
6562
6563 if (emitted)
6564 statement("");
6565 }
6566
6567 void CompilerMSL::emit_resources()
6568 {
6569 declare_constant_arrays();
6570 declare_undefined_values();
6571
6572 // Emit the special [[stage_in]] and [[stage_out]] interface blocks which we created.
6573 emit_interface_block(stage_out_var_id);
6574 emit_interface_block(patch_stage_out_var_id);
6575 emit_interface_block(stage_in_var_id);
6576 emit_interface_block(patch_stage_in_var_id);
6577 }
6578
6579 // Emit declarations for the specialization Metal function constants
6580 void CompilerMSL::emit_specialization_constants_and_structs()
6581 {
6582 SpecializationConstant wg_x, wg_y, wg_z;
6583 ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
6584 bool emitted = false;
6585
6586 unordered_set<uint32_t> declared_structs;
6587 unordered_set<uint32_t> aligned_structs;
6588
6589 // First, we need to deal with scalar block layout.
6590 // A struct may have to be placed at an alignment which does not match the innate alignment of the struct itself.
6591 // If such a case exists for a struct, we must force all elements of the struct to become packed_ types.
6592 // This makes the struct alignment as small as physically possible.
6593 // When we actually align the struct later, we can insert padding as necessary to make the packed members behave like normally aligned types.
6594 ir.for_each_typed_id<SPIRType>([&](uint32_t type_id, const SPIRType &type) {
6595 if (type.basetype == SPIRType::Struct &&
6596 has_extended_decoration(type_id, SPIRVCrossDecorationBufferBlockRepacked))
6597 mark_scalar_layout_structs(type);
6598 });
6599
6600 bool builtin_block_type_is_required = false;
6601 // Very special case. If gl_PerVertex is initialized as an array (tessellation)
6602 // we may have to emit the gl_PerVertex struct type so that we can emit a constant LUT.
6603 ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
6604 auto &type = this->get<SPIRType>(c.constant_type);
6605 if (is_array(type) && has_decoration(type.self, DecorationBlock) && is_builtin_type(type))
6606 builtin_block_type_is_required = true;
6607 });
6608
6609 // Very particular use of the soft loop lock.
6610 // align_struct may need to create custom types on the fly, but we don't care about
6611 // these types for purpose of iterating over them in ir.ids_for_type and friends.
6612 auto loop_lock = ir.create_loop_soft_lock();
6613
6614 for (auto &id_ : ir.ids_for_constant_or_type)
6615 {
6616 auto &id = ir.ids[id_];
6617
6618 if (id.get_type() == TypeConstant)
6619 {
6620 auto &c = id.get<SPIRConstant>();
6621
6622 if (c.self == workgroup_size_id)
6623 {
6624 // TODO: This can be expressed as a [[threads_per_threadgroup]] input semantic, but we need to know
6625 // the work group size at compile time in SPIR-V, and [[threads_per_threadgroup]] would need to be passed around as a global.
6626 // The work group size may be a specialization constant.
6627 statement("constant uint3 ", builtin_to_glsl(BuiltInWorkgroupSize, StorageClassWorkgroup),
6628 " [[maybe_unused]] = ", constant_expression(get<SPIRConstant>(workgroup_size_id)), ";");
6629 emitted = true;
6630 }
6631 else if (c.specialization)
6632 {
6633 auto &type = get<SPIRType>(c.constant_type);
6634 string sc_type_name = type_to_glsl(type);
6635 string sc_name = to_name(c.self);
6636 string sc_tmp_name = sc_name + "_tmp";
6637
6638 // Function constants are only supported in MSL 1.2 and later.
6639 // If we don't support it just declare the "default" directly.
6640 // This "default" value can be overridden to the true specialization constant by the API user.
6641 // Specialization constants which are used as array length expressions cannot be function constants in MSL,
6642 // so just fall back to macros.
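				// The emitted pattern looks roughly like this (names illustrative):
				//   constant int foo_tmp [[function_constant(10)]];
				//   constant int foo = is_function_constant_defined(foo_tmp) ? foo_tmp : 4;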
6643 if (msl_options.supports_msl_version(1, 2) && has_decoration(c.self, DecorationSpecId) &&
6644 !c.is_used_as_array_length)
6645 {
6646 uint32_t constant_id = get_decoration(c.self, DecorationSpecId);
6647 // Only scalar, non-composite values can be function constants.
6648 statement("constant ", sc_type_name, " ", sc_tmp_name, " [[function_constant(", constant_id,
6649 ")]];");
6650 statement("constant ", sc_type_name, " ", sc_name, " = is_function_constant_defined(", sc_tmp_name,
6651 ") ? ", sc_tmp_name, " : ", constant_expression(c), ";");
6652 }
6653 else if (has_decoration(c.self, DecorationSpecId))
6654 {
6655 // Fallback to macro overrides.
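					// The emitted pattern (macro name generated below):
					//   #ifndef <macro>
					//   #define <macro> <default value>
					//   #endif
					//   constant <type> <name> = <macro>;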
6656 c.specialization_constant_macro_name =
6657 constant_value_macro_name(get_decoration(c.self, DecorationSpecId));
6658
6659 statement("#ifndef ", c.specialization_constant_macro_name);
6660 statement("#define ", c.specialization_constant_macro_name, " ", constant_expression(c));
6661 statement("#endif");
6662 statement("constant ", sc_type_name, " ", sc_name, " = ", c.specialization_constant_macro_name,
6663 ";");
6664 }
6665 else
6666 {
6667 // Composite specialization constants must be built from other specialization constants.
6668 statement("constant ", sc_type_name, " ", sc_name, " = ", constant_expression(c), ";");
6669 }
6670 emitted = true;
6671 }
6672 }
6673 else if (id.get_type() == TypeConstantOp)
6674 {
6675 auto &c = id.get<SPIRConstantOp>();
6676 auto &type = get<SPIRType>(c.basetype);
6677 auto name = to_name(c.self);
6678 statement("constant ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
6679 emitted = true;
6680 }
6681 else if (id.get_type() == TypeType)
6682 {
6683 // Output non-builtin interface structs. These include local function structs
6684 // and structs nested within uniform and read-write buffers.
6685 auto &type = id.get<SPIRType>();
6686 TypeID type_id = type.self;
6687
6688 bool is_struct = (type.basetype == SPIRType::Struct) && type.array.empty() && !type.pointer;
6689 bool is_block =
6690 has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);
6691
6692 bool is_builtin_block = is_block && is_builtin_type(type);
6693 bool is_declarable_struct = is_struct && (!is_builtin_block || builtin_block_type_is_required);
6694
6695 // We'll declare this later.
6696 if (stage_out_var_id && get_stage_out_struct_type().self == type_id)
6697 is_declarable_struct = false;
6698 if (patch_stage_out_var_id && get_patch_stage_out_struct_type().self == type_id)
6699 is_declarable_struct = false;
6700 if (stage_in_var_id && get_stage_in_struct_type().self == type_id)
6701 is_declarable_struct = false;
6702 if (patch_stage_in_var_id && get_patch_stage_in_struct_type().self == type_id)
6703 is_declarable_struct = false;
6704
6705 // Special case. Declare builtin struct anyways if we need to emit a threadgroup version of it.
6706 if (stage_out_masked_builtin_type_id == type_id)
6707 is_declarable_struct = true;
6708
6709 // Align and emit declarable structs...but avoid declaring each more than once.
6710 if (is_declarable_struct && declared_structs.count(type_id) == 0)
6711 {
6712 if (emitted)
6713 statement("");
6714 emitted = false;
6715
6716 declared_structs.insert(type_id);
6717
6718 if (has_extended_decoration(type_id, SPIRVCrossDecorationBufferBlockRepacked))
6719 align_struct(type, aligned_structs);
6720
6721 // Make sure we declare the underlying struct type, and not the "decorated" type with pointers, etc.
6722 emit_struct(get<SPIRType>(type_id));
6723 }
6724 }
6725 }
6726
6727 if (emitted)
6728 statement("");
6729 }
6730
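// SPIR-V's FUnord* comparisons return true when either operand is NaN. MSL has no
// direct equivalents, so we emit (isunordered(a, b) || a op b), which matches that
// semantic because the ordered comparison is false for NaN operands.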
6731 void CompilerMSL::emit_binary_unord_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
6732 const char *op)
6733 {
6734 bool forward = should_forward(op0) && should_forward(op1);
6735 emit_op(result_type, result_id,
6736 join("(isunordered(", to_enclosed_unpacked_expression(op0), ", ", to_enclosed_unpacked_expression(op1),
6737 ") || ", to_enclosed_unpacked_expression(op0), " ", op, " ", to_enclosed_unpacked_expression(op1),
6738 ")"),
6739 forward);
6740
6741 inherit_expression_dependencies(result_id, op0);
6742 inherit_expression_dependencies(result_id, op1);
6743 }
6744
6745 bool CompilerMSL::emit_tessellation_io_load(uint32_t result_type_id, uint32_t id, uint32_t ptr)
6746 {
6747 auto &ptr_type = expression_type(ptr);
6748 auto &result_type = get<SPIRType>(result_type_id);
6749 if (ptr_type.storage != StorageClassInput && ptr_type.storage != StorageClassOutput)
6750 return false;
6751 if (ptr_type.storage == StorageClassOutput && get_execution_model() == ExecutionModelTessellationEvaluation)
6752 return false;
6753
6754 if (has_decoration(ptr, DecorationPatch))
6755 return false;
6756 bool ptr_is_io_variable = ir.ids[ptr].get_type() == TypeVariable;
6757
6758 bool flattened_io = variable_storage_requires_stage_io(ptr_type.storage);
6759
6760 bool flat_data_type = flattened_io &&
6761 (is_matrix(result_type) || is_array(result_type) || result_type.basetype == SPIRType::Struct);
6762
6763 // Edge case: even with multi-patch workgroups, we still need to unroll the load
6764 // if we're loading control points directly.
6765 if (ptr_is_io_variable && is_array(result_type))
6766 flat_data_type = true;
6767
6768 if (!flat_data_type)
6769 return false;
6770
6771 // Now, we must unflatten a composite type and take care of interleaving array access with gl_in/gl_out.
6772 // Lots of painful code duplication since we *really* should not unroll these kinds of loads in entry point fixup
6773 // unless we're forced to, i.e., when the code emits suboptimal OpLoads.
6774 string expr;
6775
6776 uint32_t interface_index = get_extended_decoration(ptr, SPIRVCrossDecorationInterfaceMemberIndex);
6777 auto *var = maybe_get_backing_variable(ptr);
6778 auto &expr_type = get_pointee_type(ptr_type.self);
6779
6780 const auto &iface_type = expression_type(stage_in_ptr_var_id);
6781
6782 if (!flattened_io)
6783 {
6784 // Simplest case for multi-patch workgroups, just unroll array as-is.
6785 if (interface_index == uint32_t(-1))
6786 return false;
6787
6788 expr += type_to_glsl(result_type) + "({ ";
6789 uint32_t num_control_points = to_array_size_literal(result_type, uint32_t(result_type.array.size()) - 1);
6790
6791 for (uint32_t i = 0; i < num_control_points; i++)
6792 {
6793 const uint32_t indices[2] = { i, interface_index };
6794 AccessChainMeta meta;
6795 expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
6796 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6797 if (i + 1 < num_control_points)
6798 expr += ", ";
6799 }
6800 expr += " })";
6801 }
6802 else if (result_type.array.size() > 2)
6803 {
6804 SPIRV_CROSS_THROW("Cannot load tessellation IO variables with more than 2 dimensions.");
6805 }
6806 else if (result_type.array.size() == 2)
6807 {
6808 if (!ptr_is_io_variable)
6809 SPIRV_CROSS_THROW("Loading an array-of-array must be loaded directly from an IO variable.");
6810 if (interface_index == uint32_t(-1))
6811 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6812 if (result_type.basetype == SPIRType::Struct || is_matrix(result_type))
6813 SPIRV_CROSS_THROW("Cannot load array-of-array of composite type in tessellation IO.");
6814
6815 expr += type_to_glsl(result_type) + "({ ";
6816 uint32_t num_control_points = to_array_size_literal(result_type, 1);
6817 uint32_t base_interface_index = interface_index;
6818
6819 auto &sub_type = get<SPIRType>(result_type.parent_type);
6820
6821 for (uint32_t i = 0; i < num_control_points; i++)
6822 {
6823 expr += type_to_glsl(sub_type) + "({ ";
6824 interface_index = base_interface_index;
6825 uint32_t array_size = to_array_size_literal(result_type, 0);
6826 for (uint32_t j = 0; j < array_size; j++, interface_index++)
6827 {
6828 const uint32_t indices[2] = { i, interface_index };
6829
6830 AccessChainMeta meta;
6831 expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
6832 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6833 if (!is_matrix(sub_type) && sub_type.basetype != SPIRType::Struct &&
6834 expr_type.vecsize > sub_type.vecsize)
6835 expr += vector_swizzle(sub_type.vecsize, 0);
6836
6837 if (j + 1 < array_size)
6838 expr += ", ";
6839 }
6840 expr += " })";
6841 if (i + 1 < num_control_points)
6842 expr += ", ";
6843 }
6844 expr += " })";
6845 }
6846 else if (result_type.basetype == SPIRType::Struct)
6847 {
6848 bool is_array_of_struct = is_array(result_type);
6849 if (is_array_of_struct && !ptr_is_io_variable)
6850 SPIRV_CROSS_THROW("Loading array of struct from IO variable must come directly from IO variable.");
6851
6852 uint32_t num_control_points = 1;
6853 if (is_array_of_struct)
6854 {
6855 num_control_points = to_array_size_literal(result_type, 0);
6856 expr += type_to_glsl(result_type) + "({ ";
6857 }
6858
6859 auto &struct_type = is_array_of_struct ? get<SPIRType>(result_type.parent_type) : result_type;
6860 assert(struct_type.array.empty());
6861
6862 for (uint32_t i = 0; i < num_control_points; i++)
6863 {
6864 expr += type_to_glsl(struct_type) + "{ ";
6865 for (uint32_t j = 0; j < uint32_t(struct_type.member_types.size()); j++)
6866 {
6867 // The base interface index is stored per variable for structs.
6868 if (var)
6869 {
6870 interface_index =
6871 get_extended_member_decoration(var->self, j, SPIRVCrossDecorationInterfaceMemberIndex);
6872 }
6873
6874 if (interface_index == uint32_t(-1))
6875 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6876
6877 const auto &mbr_type = get<SPIRType>(struct_type.member_types[j]);
6878 const auto &expr_mbr_type = get<SPIRType>(expr_type.member_types[j]);
6879 if (is_matrix(mbr_type) && ptr_type.storage == StorageClassInput)
6880 {
6881 expr += type_to_glsl(mbr_type) + "(";
6882 for (uint32_t k = 0; k < mbr_type.columns; k++, interface_index++)
6883 {
6884 if (is_array_of_struct)
6885 {
6886 const uint32_t indices[2] = { i, interface_index };
6887 AccessChainMeta meta;
6888 expr += access_chain_internal(
6889 stage_in_ptr_var_id, indices, 2,
6890 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6891 }
6892 else
6893 expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6894 if (expr_mbr_type.vecsize > mbr_type.vecsize)
6895 expr += vector_swizzle(mbr_type.vecsize, 0);
6896
6897 if (k + 1 < mbr_type.columns)
6898 expr += ", ";
6899 }
6900 expr += ")";
6901 }
6902 else if (is_array(mbr_type))
6903 {
6904 expr += type_to_glsl(mbr_type) + "({ ";
6905 uint32_t array_size = to_array_size_literal(mbr_type, 0);
6906 for (uint32_t k = 0; k < array_size; k++, interface_index++)
6907 {
6908 if (is_array_of_struct)
6909 {
6910 const uint32_t indices[2] = { i, interface_index };
6911 AccessChainMeta meta;
6912 expr += access_chain_internal(
6913 stage_in_ptr_var_id, indices, 2,
6914 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6915 }
6916 else
6917 expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6918 if (expr_mbr_type.vecsize > mbr_type.vecsize)
6919 expr += vector_swizzle(mbr_type.vecsize, 0);
6920
6921 if (k + 1 < array_size)
6922 expr += ", ";
6923 }
6924 expr += " })";
6925 }
6926 else
6927 {
6928 if (is_array_of_struct)
6929 {
6930 const uint32_t indices[2] = { i, interface_index };
6931 AccessChainMeta meta;
6932 expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
6933 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT,
6934 &meta);
6935 }
6936 else
6937 expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6938 if (expr_mbr_type.vecsize > mbr_type.vecsize)
6939 expr += vector_swizzle(mbr_type.vecsize, 0);
6940 }
6941
6942 if (j + 1 < struct_type.member_types.size())
6943 expr += ", ";
6944 }
6945 expr += " }";
6946 if (i + 1 < num_control_points)
6947 expr += ", ";
6948 }
6949 if (is_array_of_struct)
6950 expr += " })";
6951 }
6952 else if (is_matrix(result_type))
6953 {
6954 bool is_array_of_matrix = is_array(result_type);
6955 if (is_array_of_matrix && !ptr_is_io_variable)
6956 SPIRV_CROSS_THROW("Loading array of matrix from IO variable must come directly from IO variable.");
6957 if (interface_index == uint32_t(-1))
6958 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6959
6960 if (is_array_of_matrix)
6961 {
6962 // Loading a matrix from each control point.
6963 uint32_t base_interface_index = interface_index;
6964 uint32_t num_control_points = to_array_size_literal(result_type, 0);
6965 expr += type_to_glsl(result_type) + "({ ";
6966
6967 auto &matrix_type = get_variable_element_type(get<SPIRVariable>(ptr));
6968
6969 for (uint32_t i = 0; i < num_control_points; i++)
6970 {
6971 interface_index = base_interface_index;
6972 expr += type_to_glsl(matrix_type) + "(";
6973 for (uint32_t j = 0; j < result_type.columns; j++, interface_index++)
6974 {
6975 const uint32_t indices[2] = { i, interface_index };
6976
6977 AccessChainMeta meta;
6978 expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
6979 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6980 if (expr_type.vecsize > result_type.vecsize)
6981 expr += vector_swizzle(result_type.vecsize, 0);
6982 if (j + 1 < result_type.columns)
6983 expr += ", ";
6984 }
6985 expr += ")";
6986 if (i + 1 < num_control_points)
6987 expr += ", ";
6988 }
6989
6990 expr += " })";
6991 }
6992 else
6993 {
6994 expr += type_to_glsl(result_type) + "(";
6995 for (uint32_t i = 0; i < result_type.columns; i++, interface_index++)
6996 {
6997 expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6998 if (expr_type.vecsize > result_type.vecsize)
6999 expr += vector_swizzle(result_type.vecsize, 0);
7000 if (i + 1 < result_type.columns)
7001 expr += ", ";
7002 }
7003 expr += ")";
7004 }
7005 }
7006 else if (ptr_is_io_variable)
7007 {
7008 assert(is_array(result_type));
7009 assert(result_type.array.size() == 1);
7010 if (interface_index == uint32_t(-1))
7011 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
7012
7013 // We're loading an array directly from a global variable.
7014 // This means we're loading one member from each control point.
7015 expr += type_to_glsl(result_type) + "({ ";
7016 uint32_t num_control_points = to_array_size_literal(result_type, 0);
7017
7018 for (uint32_t i = 0; i < num_control_points; i++)
7019 {
7020 const uint32_t indices[2] = { i, interface_index };
7021
7022 AccessChainMeta meta;
7023 expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
7024 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
7025 if (expr_type.vecsize > result_type.vecsize)
7026 expr += vector_swizzle(result_type.vecsize, 0);
7027
7028 if (i + 1 < num_control_points)
7029 expr += ", ";
7030 }
7031 expr += " })";
7032 }
7033 else
7034 {
7035 // We're loading an array from a concrete control point.
7036 assert(is_array(result_type));
7037 assert(result_type.array.size() == 1);
7038 if (interface_index == uint32_t(-1))
7039 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
7040
7041 expr += type_to_glsl(result_type) + "({ ";
7042 uint32_t array_size = to_array_size_literal(result_type, 0);
7043 for (uint32_t i = 0; i < array_size; i++, interface_index++)
7044 {
7045 expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
7046 if (expr_type.vecsize > result_type.vecsize)
7047 expr += vector_swizzle(result_type.vecsize, 0);
7048 if (i + 1 < array_size)
7049 expr += ", ";
7050 }
7051 expr += " })";
7052 }
7053
7054 emit_op(result_type_id, id, expr, false);
7055 register_read(id, ptr, false);
7056 return true;
7057 }
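// Illustrative sketch of the unrolled loads built above (names assumed for
// illustration): loading a 3-control-point float4 input directly from gl_in
// expands to a constructor along the lines of
//
//   spvUnsafeArray<float4, 3>({ gl_in[0].m, gl_in[1].m, gl_in[2].m })
//
// where "m" stands for whatever flattened member the interface index resolves to.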
7058
7059 bool CompilerMSL::emit_tessellation_access_chain(const uint32_t *ops, uint32_t length)
7060 {
7061 // If this is a per-vertex output, remap it to the I/O array buffer.
7062
7063 // Any object which did not go through IO flattening shenanigans will go there instead.
7064 	// We will unflatten on demand as needed, but not all possible cases can be supported, especially with arrays.
7065
7066 auto *var = maybe_get_backing_variable(ops[2]);
7067 bool patch = false;
7068 bool flat_data = false;
7069 bool ptr_is_chain = false;
7070 bool flatten_composites = false;
7071
7072 bool is_block = false;
7073
7074 if (var)
7075 is_block = has_decoration(get_variable_data_type(*var).self, DecorationBlock);
7076
7077 if (var)
7078 {
7079 flatten_composites = variable_storage_requires_stage_io(var->storage);
7080 patch = has_decoration(ops[2], DecorationPatch) || is_patch_block(get_variable_data_type(*var));
7081
7082 // Should match strip_array in add_interface_block.
7083 flat_data = var->storage == StorageClassInput ||
7084 (var->storage == StorageClassOutput && get_execution_model() == ExecutionModelTessellationControl);
7085
7086 // Patch inputs are treated as normal block IO variables, so they don't deal with this path at all.
7087 if (patch && (!is_block || var->storage == StorageClassInput))
7088 flat_data = false;
7089
7090 // We might have a chained access chain, where
7091 // we first take the access chain to the control point, and then we chain into a member or something similar.
7092 // In this case, we need to skip gl_in/gl_out remapping.
7093 // Also, skip ptr chain for patches.
7094 ptr_is_chain = var->self != ID(ops[2]);
7095 }
7096
7097 bool builtin_variable = false;
7098 bool variable_is_flat = false;
7099
7100 if (var && flat_data)
7101 {
7102 builtin_variable = is_builtin_variable(*var);
7103
7104 BuiltIn bi_type = BuiltInMax;
7105 if (builtin_variable && !is_block)
7106 bi_type = BuiltIn(get_decoration(var->self, DecorationBuiltIn));
7107
7108 variable_is_flat = !builtin_variable || is_block ||
7109 bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
7110 bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance;
7111 }
7112
7113 if (variable_is_flat)
7114 {
7115 // If output is masked, it is emitted as a "normal" variable, just go through normal code paths.
7116 // Only check this for the first level of access chain.
7117 // Dealing with this for partial access chains should be possible, but awkward.
7118 if (var->storage == StorageClassOutput && !ptr_is_chain)
7119 {
7120 bool masked = false;
7121 if (is_block)
7122 {
7123 uint32_t relevant_member_index = patch ? 3 : 4;
7124 // FIXME: This won't work properly if the application first access chains into gl_out element,
7125 // then access chains into the member. Super weird, but theoretically possible ...
7126 if (length > relevant_member_index)
7127 {
7128 uint32_t mbr_idx = get<SPIRConstant>(ops[relevant_member_index]).scalar();
7129 masked = is_stage_output_block_member_masked(*var, mbr_idx, true);
7130 }
7131 }
7132 else if (var)
7133 masked = is_stage_output_variable_masked(*var);
7134
7135 if (masked)
7136 return false;
7137 }
7138
7139 AccessChainMeta meta;
7140 SmallVector<uint32_t> indices;
7141 uint32_t next_id = ir.increase_bound_by(1);
7142
7143 indices.reserve(length - 3 + 1);
7144
7145 uint32_t first_non_array_index = (ptr_is_chain ? 3 : 4) - (patch ? 1 : 0);
7146
7147 VariableID stage_var_id;
7148 if (patch)
7149 stage_var_id = var->storage == StorageClassInput ? patch_stage_in_var_id : patch_stage_out_var_id;
7150 else
7151 stage_var_id = var->storage == StorageClassInput ? stage_in_ptr_var_id : stage_out_ptr_var_id;
7152
7153 VariableID ptr = ptr_is_chain ? VariableID(ops[2]) : stage_var_id;
7154 if (!ptr_is_chain && !patch)
7155 {
7156 // Index into gl_in/gl_out with first array index.
7157 indices.push_back(ops[first_non_array_index - 1]);
7158 }
7159
7160 auto &result_ptr_type = get<SPIRType>(ops[0]);
7161
7162 uint32_t const_mbr_id = next_id++;
7163 uint32_t index = get_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex);
7164
7165 // If we have a pointer chain expression, and we are no longer pointing to a composite
7166 // object, we are in the clear. There is no longer a need to flatten anything.
7167 bool further_access_chain_is_trivial = false;
7168 if (ptr_is_chain && flatten_composites)
7169 {
7170 auto &ptr_type = expression_type(ptr);
7171 if (!is_array(ptr_type) && !is_matrix(ptr_type) && ptr_type.basetype != SPIRType::Struct)
7172 further_access_chain_is_trivial = true;
7173 }
7174
7175 if (!further_access_chain_is_trivial && (flatten_composites || is_block))
7176 {
7177 uint32_t i = first_non_array_index;
7178 auto *type = &get_variable_element_type(*var);
7179 if (index == uint32_t(-1) && length >= (first_non_array_index + 1))
7180 {
7181 			// Maybe this is a struct type in the Input storage class, in which case
7182 			// the interface index is stored as a decoration on the corresponding member.
7183 uint32_t mbr_idx = get_constant(ops[first_non_array_index]).scalar();
7184 index = get_extended_member_decoration(var->self, mbr_idx,
7185 SPIRVCrossDecorationInterfaceMemberIndex);
7186 assert(index != uint32_t(-1));
7187 i++;
7188 type = &get<SPIRType>(type->member_types[mbr_idx]);
7189 }
7190
7191 // In this case, we're poking into flattened structures and arrays, so now we have to
7192 // combine the following indices. If we encounter a non-constant index,
7193 // we're hosed.
7194 for (; flatten_composites && i < length; ++i)
7195 {
7196 if (!is_array(*type) && !is_matrix(*type) && type->basetype != SPIRType::Struct)
7197 break;
7198
7199 auto *c = maybe_get<SPIRConstant>(ops[i]);
7200 if (!c || c->specialization)
7201 SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable in tessellation. "
7202 "This is currently unsupported.");
7203
7204 // We're in flattened space, so just increment the member index into IO block.
7205 // We can only do this once in the current implementation, so either:
7206 // Struct, Matrix or 1-dimensional array for a control point.
7207 if (type->basetype == SPIRType::Struct && var->storage == StorageClassOutput)
7208 {
7209 // Need to consider holes, since individual block members might be masked away.
7210 uint32_t mbr_idx = c->scalar();
7211 for (uint32_t j = 0; j < mbr_idx; j++)
7212 if (!is_stage_output_block_member_masked(*var, j, true))
7213 index++;
7214 }
7215 else
7216 index += c->scalar();
7217
7218 if (type->parent_type)
7219 type = &get<SPIRType>(type->parent_type);
7220 else if (type->basetype == SPIRType::Struct)
7221 type = &get<SPIRType>(type->member_types[c->scalar()]);
7222 }
7223
7224 			// We're not going to emit the actual member name; any further OpLoad will take care of that.
7225 // Tag the access chain with the member index we're referencing.
7226 bool defer_access_chain = flatten_composites && (is_matrix(result_ptr_type) || is_array(result_ptr_type) ||
7227 result_ptr_type.basetype == SPIRType::Struct);
7228
7229 if (!defer_access_chain)
7230 {
7231 // Access the appropriate member of gl_in/gl_out.
7232 set<SPIRConstant>(const_mbr_id, get_uint_type_id(), index, false);
7233 indices.push_back(const_mbr_id);
7234
7235 // Member index is now irrelevant.
7236 index = uint32_t(-1);
7237
7238 // Append any straggling access chain indices.
7239 if (i < length)
7240 indices.insert(indices.end(), ops + i, ops + length);
7241 }
7242 else
7243 {
7244 // We must have consumed the entire access chain if we're deferring it.
7245 assert(i == length);
7246 }
7247
7248 if (index != uint32_t(-1))
7249 set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, index);
7250 else
7251 unset_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex);
7252 }
7253 else
7254 {
7255 if (index != uint32_t(-1))
7256 {
7257 set<SPIRConstant>(const_mbr_id, get_uint_type_id(), index, false);
7258 indices.push_back(const_mbr_id);
7259 }
7260
7261 // Member index is now irrelevant.
7262 index = uint32_t(-1);
7263 unset_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex);
7264
7265 indices.insert(indices.end(), ops + first_non_array_index, ops + length);
7266 }
7267
7268 // We use the pointer to the base of the input/output array here,
7269 // so this is always a pointer chain.
7270 string e;
7271
7272 if (!ptr_is_chain)
7273 {
7274 // This is the start of an access chain, use ptr_chain to index into control point array.
7275 e = access_chain(ptr, indices.data(), uint32_t(indices.size()), result_ptr_type, &meta, !patch);
7276 }
7277 else
7278 {
7279 // If we're accessing a struct, we need to use member indices which are based on the IO block,
7280 // not actual struct type, so we have to use a split access chain here where
7281 // first path resolves the control point index, i.e. gl_in[index], and second half deals with
7282 // looking up flattened member name.
7283
7284 // However, it is possible that we partially accessed a struct,
7285 // by taking pointer to member inside the control-point array.
7286 // For this case, we fall back to a natural access chain since we have already dealt with remapping struct members.
7287 // One way to check this here is if we have 2 implied read expressions.
7288 // First one is the gl_in/gl_out struct itself, then an index into that array.
7289 // If we have traversed further, we use a normal access chain formulation.
7290 auto *ptr_expr = maybe_get<SPIRExpression>(ptr);
7291 bool split_access_chain_formulation = flatten_composites && ptr_expr &&
7292 ptr_expr->implied_read_expressions.size() == 2 &&
7293 !further_access_chain_is_trivial;
7294
7295 if (split_access_chain_formulation)
7296 {
7297 e = join(to_expression(ptr),
7298 access_chain_internal(stage_var_id, indices.data(), uint32_t(indices.size()),
7299 ACCESS_CHAIN_CHAIN_ONLY_BIT, &meta));
7300 }
7301 else
7302 {
7303 e = access_chain_internal(ptr, indices.data(), uint32_t(indices.size()), 0, &meta);
7304 }
7305 }
7306
7307 // Get the actual type of the object that was accessed. If it's a vector type and we changed it,
7308 // then we'll need to add a swizzle.
7309 // For this, we can't necessarily rely on the type of the base expression, because it might be
7310 // another access chain, and it will therefore already have the "correct" type.
7311 auto *expr_type = &get_variable_data_type(*var);
7312 if (has_extended_decoration(ops[2], SPIRVCrossDecorationTessIOOriginalInputTypeID))
7313 expr_type = &get<SPIRType>(get_extended_decoration(ops[2], SPIRVCrossDecorationTessIOOriginalInputTypeID));
7314 for (uint32_t i = 3; i < length; i++)
7315 {
7316 if (!is_array(*expr_type) && expr_type->basetype == SPIRType::Struct)
7317 expr_type = &get<SPIRType>(expr_type->member_types[get<SPIRConstant>(ops[i]).scalar()]);
7318 else
7319 expr_type = &get<SPIRType>(expr_type->parent_type);
7320 }
7321 if (!is_array(*expr_type) && !is_matrix(*expr_type) && expr_type->basetype != SPIRType::Struct &&
7322 expr_type->vecsize > result_ptr_type.vecsize)
7323 e += vector_swizzle(result_ptr_type.vecsize, 0);
7324
7325 auto &expr = set<SPIRExpression>(ops[1], move(e), ops[0], should_forward(ops[2]));
7326 expr.loaded_from = var->self;
7327 expr.need_transpose = meta.need_transpose;
7328 expr.access_chain = true;
7329
7330 // Mark the result as being packed if necessary.
7331 if (meta.storage_is_packed)
7332 set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypePacked);
7333 if (meta.storage_physical_type != 0)
7334 set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypeID, meta.storage_physical_type);
7335 if (meta.storage_is_invariant)
7336 set_decoration(ops[1], DecorationInvariant);
7337 // Save the type we found in case the result is used in another access chain.
7338 set_extended_decoration(ops[1], SPIRVCrossDecorationTessIOOriginalInputTypeID, expr_type->self);
7339
7340 // If we have some expression dependencies in our access chain, this access chain is technically a forwarded
7341 // temporary which could be subject to invalidation.
7342 		// Need to assume we're forwarded while calling inherit_expression_dependencies.
7343 forwarded_temporaries.insert(ops[1]);
7344 // The access chain itself is never forced to a temporary, but its dependencies might.
7345 suppressed_usage_tracking.insert(ops[1]);
7346
7347 for (uint32_t i = 2; i < length; i++)
7348 {
7349 inherit_expression_dependencies(ops[1], ops[i]);
7350 add_implied_read_expression(expr, ops[i]);
7351 }
7352
7353 // If we have no dependencies after all, i.e., all indices in the access chain are immutable temporaries,
7354 // we're not forwarded after all.
7355 if (expr.expression_dependencies.empty())
7356 forwarded_temporaries.erase(ops[1]);
7357
7358 return true;
7359 }
7360
7361 // If this is the inner tessellation level, and we're tessellating triangles,
7362 // drop the last index. It isn't an array in this case, so we can't have an
7363 // array reference here. We need to make this ID a variable instead of an
7364 // expression so we don't try to dereference it as a variable pointer.
7365 // Don't do this if the index is a constant 1, though. We need to drop stores
7366 // to that one.
7367 auto *m = ir.find_meta(var ? var->self : ID(0));
7368 if (get_execution_model() == ExecutionModelTessellationControl && var && m &&
7369 m->decoration.builtin_type == BuiltInTessLevelInner && get_entry_point().flags.get(ExecutionModeTriangles))
7370 {
7371 auto *c = maybe_get<SPIRConstant>(ops[3]);
7372 if (c && c->scalar() == 1)
7373 return false;
7374 auto &dest_var = set<SPIRVariable>(ops[1], *var);
7375 dest_var.basetype = ops[0];
7376 ir.meta[ops[1]] = ir.meta[ops[2]];
7377 inherit_expression_dependencies(ops[1], ops[2]);
7378 return true;
7379 }
7380
7381 return false;
7382 }
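// Illustrative sketch (assumed names): in a tessellation control shader, an
// access chain equivalent to
//
//   OpAccessChain %ptr %gl_out %gl_InvocationID %int_0
//
// is rebuilt against the compiler-generated control-point array, producing MSL
// along the lines of gl_out[gl_InvocationID].gl_Position, with the flattened
// member picked via SPIRVCrossDecorationInterfaceMemberIndex rather than the
// original struct member index.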
7383
7384 bool CompilerMSL::is_out_of_bounds_tessellation_level(uint32_t id_lhs)
7385 {
7386 if (!get_entry_point().flags.get(ExecutionModeTriangles))
7387 return false;
7388
7389 // In SPIR-V, TessLevelInner always has two elements and TessLevelOuter always has
7390 // four. This is true even if we are tessellating triangles. This allows clients
7391 // to use a single tessellation control shader with multiple tessellation evaluation
7392 // shaders.
7393 // In Metal, however, only the first element of TessLevelInner and the first three
7394 // of TessLevelOuter are accessible. This stems from how in Metal, the tessellation
7395 // levels must be stored to a dedicated buffer in a particular format that depends
7396 // on the patch type. Therefore, in Triangles mode, any access to the second
7397 // inner level or the fourth outer level must be dropped.
7398 const auto *e = maybe_get<SPIRExpression>(id_lhs);
7399 if (!e || !e->access_chain)
7400 return false;
7401 BuiltIn builtin = BuiltIn(get_decoration(e->loaded_from, DecorationBuiltIn));
7402 if (builtin != BuiltInTessLevelInner && builtin != BuiltInTessLevelOuter)
7403 return false;
7404 auto *c = maybe_get<SPIRConstant>(e->implied_read_expressions[1]);
7405 if (!c)
7406 return false;
7407 return (builtin == BuiltInTessLevelInner && c->scalar() == 1) ||
7408 (builtin == BuiltInTessLevelOuter && c->scalar() == 3);
7409 }
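// Example of the stores this check drops in Triangles mode (a sketch):
//
//   gl_TessLevelInner[1] = 1.0; // no second inner level in Metal; store is dropped
//   gl_TessLevelOuter[3] = 1.0; // no fourth outer level in Metal; store is dropped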
7410
7411 void CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
7412 spv::StorageClass storage, bool &is_packed)
7413 {
7414 // If there is any risk of writes happening with the access chain in question,
7415 // and there is a risk of concurrent write access to other components,
7416 // we must cast the access chain to a plain pointer to ensure we only access the exact scalars we expect.
7417 // The MSL compiler refuses to allow component-level access for any non-packed vector types.
7418 if (!is_packed && (storage == StorageClassStorageBuffer || storage == StorageClassWorkgroup))
7419 {
7420 const char *addr_space = storage == StorageClassWorkgroup ? "threadgroup" : "device";
7421 expr = join("((", addr_space, " ", type_to_glsl(type), "*)&", enclose_expression(expr), ")");
7422
7423 // Further indexing should happen with packed rules (array index, not swizzle).
7424 is_packed = true;
7425 }
7426 }
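// Illustrative sketch (assumed names): for a device-storage buffer "buf" with a
// float4 member "v", a component store like buf.v.y = x is rewritten through the
// cast above into
//
//   ((device float*)&buf.v)[1] = x;
//
// so only the addressed scalar is touched, and concurrent writes to the other
// components remain well-defined.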
7427
7428 bool CompilerMSL::access_chain_needs_stage_io_builtin_translation(uint32_t base)
7429 {
7430 auto *var = maybe_get_backing_variable(base);
7431 if (!var || !is_tessellation_shader())
7432 return true;
7433
7434 // We only need to rewrite builtin access chains when accessing flattened builtins like gl_ClipDistance_N.
7435 // Avoid overriding it back to just gl_ClipDistance.
7436 // This can only happen in scenarios where we cannot flatten/unflatten access chains, so, the only case
7437 // where this triggers is evaluation shader inputs.
7438 bool redirect_builtin = get_execution_model() == ExecutionModelTessellationEvaluation ?
7439 var->storage == StorageClassOutput : false;
7440 return redirect_builtin;
7441 }
7442
7443 // Sets the interface member index for an access chain to a pull-model interpolant.
7444 void CompilerMSL::fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t length)
7445 {
7446 auto *var = maybe_get_backing_variable(ops[2]);
7447 if (!var || !pull_model_inputs.count(var->self))
7448 return;
7449 // Get the base index.
7450 uint32_t interface_index;
7451 auto &var_type = get_variable_data_type(*var);
7452 auto &result_type = get<SPIRType>(ops[0]);
7453 auto *type = &var_type;
7454 if (has_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex))
7455 {
7456 interface_index = get_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex);
7457 }
7458 else
7459 {
7460 // Assume an access chain into a struct variable.
7461 assert(var_type.basetype == SPIRType::Struct);
7462 auto &c = get<SPIRConstant>(ops[3 + var_type.array.size()]);
7463 interface_index =
7464 get_extended_member_decoration(var->self, c.scalar(), SPIRVCrossDecorationInterfaceMemberIndex);
7465 }
7466 // Accumulate indices. We'll have to skip over the one for the struct, if present, because we already accounted
7467 // for that getting the base index.
7468 for (uint32_t i = 3; i < length; ++i)
7469 {
7470 if (is_vector(*type) && !is_array(*type) && is_scalar(result_type))
7471 {
7472 // We don't want to combine the next index. Actually, we need to save it
7473 // so we know to apply a swizzle to the result of the interpolation.
7474 set_extended_decoration(ops[1], SPIRVCrossDecorationInterpolantComponentExpr, ops[i]);
7475 break;
7476 }
7477
7478 auto *c = maybe_get<SPIRConstant>(ops[i]);
7479 if (!c || c->specialization)
7480 SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable using pull-model "
7481 "interpolation. This is currently unsupported.");
7482
7483 if (type->parent_type)
7484 type = &get<SPIRType>(type->parent_type);
7485 else if (type->basetype == SPIRType::Struct)
7486 type = &get<SPIRType>(type->member_types[c->scalar()]);
7487
7488 if (!has_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex) &&
7489 i - 3 == var_type.array.size())
7490 continue;
7491
7492 interface_index += c->scalar();
7493 }
7494 // Save this to the access chain itself so we can recover it later when calling an interpolation function.
7495 set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, interface_index);
7496 }
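// Illustrative sketch (assumed names): for a pull-model input "attr" declared as
// a float4 array, an access chain equivalent to attr[1].z resolves here to
// base_index + 1, while the component selection (z) is parked in
// SPIRVCrossDecorationInterpolantComponentExpr so the later interpolate_at_*()
// call can apply the matching swizzle to its result.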
7497
7498 // Override for MSL-specific syntax instructions
7499 void CompilerMSL::emit_instruction(const Instruction &instruction)
7500 {
7501 #define MSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
7502 #define MSL_BOP_CAST(op, type) \
7503 emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
7504 #define MSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
7505 #define MSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
7506 #define MSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
7507 #define MSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
7508 #define MSL_BFOP_CAST(op, type) \
7509 emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
7510 #define MSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
7511 #define MSL_UNORD_BOP(op) emit_binary_unord_op(ops[0], ops[1], ops[2], ops[3], #op)
7512
7513 auto ops = stream(instruction);
7514 auto opcode = static_cast<Op>(instruction.op);
7515
7516 // If we need to do implicit bitcasts, make sure we do it with the correct type.
7517 uint32_t integer_width = get_integer_width_for_instruction(instruction);
7518 auto int_type = to_signed_basetype(integer_width);
7519 auto uint_type = to_unsigned_basetype(integer_width);
7520
7521 switch (opcode)
7522 {
7523 case OpLoad:
7524 {
7525 uint32_t id = ops[1];
7526 uint32_t ptr = ops[2];
7527 if (is_tessellation_shader())
7528 {
7529 if (!emit_tessellation_io_load(ops[0], id, ptr))
7530 CompilerGLSL::emit_instruction(instruction);
7531 }
7532 else
7533 {
7534 // Sample mask input for Metal is not an array
7535 if (BuiltIn(get_decoration(ptr, DecorationBuiltIn)) == BuiltInSampleMask)
7536 set_decoration(id, DecorationBuiltIn, BuiltInSampleMask);
7537 CompilerGLSL::emit_instruction(instruction);
7538 }
7539 break;
7540 }
7541
7542 // Comparisons
7543 case OpIEqual:
7544 MSL_BOP_CAST(==, int_type);
7545 break;
7546
7547 case OpLogicalEqual:
7548 case OpFOrdEqual:
7549 MSL_BOP(==);
7550 break;
7551
7552 case OpINotEqual:
7553 MSL_BOP_CAST(!=, int_type);
7554 break;
7555
7556 case OpLogicalNotEqual:
7557 case OpFOrdNotEqual:
7558 MSL_BOP(!=);
7559 break;
7560
7561 case OpUGreaterThan:
7562 MSL_BOP_CAST(>, uint_type);
7563 break;
7564
7565 case OpSGreaterThan:
7566 MSL_BOP_CAST(>, int_type);
7567 break;
7568
7569 case OpFOrdGreaterThan:
7570 MSL_BOP(>);
7571 break;
7572
7573 case OpUGreaterThanEqual:
7574 MSL_BOP_CAST(>=, uint_type);
7575 break;
7576
7577 case OpSGreaterThanEqual:
7578 MSL_BOP_CAST(>=, int_type);
7579 break;
7580
7581 case OpFOrdGreaterThanEqual:
7582 MSL_BOP(>=);
7583 break;
7584
7585 case OpULessThan:
7586 MSL_BOP_CAST(<, uint_type);
7587 break;
7588
7589 case OpSLessThan:
7590 MSL_BOP_CAST(<, int_type);
7591 break;
7592
7593 case OpFOrdLessThan:
7594 MSL_BOP(<);
7595 break;
7596
7597 case OpULessThanEqual:
7598 MSL_BOP_CAST(<=, uint_type);
7599 break;
7600
7601 case OpSLessThanEqual:
7602 MSL_BOP_CAST(<=, int_type);
7603 break;
7604
7605 case OpFOrdLessThanEqual:
7606 MSL_BOP(<=);
7607 break;
7608
7609 case OpFUnordEqual:
7610 MSL_UNORD_BOP(==);
7611 break;
7612
7613 case OpFUnordNotEqual:
7614 MSL_UNORD_BOP(!=);
7615 break;
7616
7617 case OpFUnordGreaterThan:
7618 MSL_UNORD_BOP(>);
7619 break;
7620
7621 case OpFUnordGreaterThanEqual:
7622 MSL_UNORD_BOP(>=);
7623 break;
7624
7625 case OpFUnordLessThan:
7626 MSL_UNORD_BOP(<);
7627 break;
7628
7629 case OpFUnordLessThanEqual:
7630 MSL_UNORD_BOP(<=);
7631 break;
7632
7633 // Derivatives
7634 case OpDPdx:
7635 case OpDPdxFine:
7636 case OpDPdxCoarse:
7637 MSL_UFOP(dfdx);
7638 register_control_dependent_expression(ops[1]);
7639 break;
7640
7641 case OpDPdy:
7642 case OpDPdyFine:
7643 case OpDPdyCoarse:
7644 MSL_UFOP(dfdy);
7645 register_control_dependent_expression(ops[1]);
7646 break;
7647
7648 case OpFwidth:
7649 case OpFwidthCoarse:
7650 case OpFwidthFine:
7651 MSL_UFOP(fwidth);
7652 register_control_dependent_expression(ops[1]);
7653 break;
7654
7655 // Bitfield
7656 case OpBitFieldInsert:
7657 {
7658 emit_bitfield_insert_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], "insert_bits", SPIRType::UInt);
7659 break;
7660 }
7661
7662 case OpBitFieldSExtract:
7663 {
7664 emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "extract_bits", int_type, int_type,
7665 SPIRType::UInt, SPIRType::UInt);
7666 break;
7667 }
7668
7669 case OpBitFieldUExtract:
7670 {
7671 emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "extract_bits", uint_type, uint_type,
7672 SPIRType::UInt, SPIRType::UInt);
7673 break;
7674 }
7675
7676 case OpBitReverse:
7677 // BitReverse does not have issues with sign since result type must match input type.
7678 MSL_UFOP(reverse_bits);
7679 break;
7680
7681 case OpBitCount:
7682 {
7683 auto basetype = expression_type(ops[2]).basetype;
7684 emit_unary_func_op_cast(ops[0], ops[1], ops[2], "popcount", basetype, basetype);
7685 break;
7686 }
7687
7688 case OpFRem:
7689 MSL_BFOP(fmod);
7690 break;
7691
7692 case OpFMul:
7693 if (msl_options.invariant_float_math || has_decoration(ops[1], DecorationNoContraction))
7694 MSL_BFOP(spvFMul);
7695 else
7696 MSL_BOP(*);
7697 break;
7698
7699 case OpFAdd:
7700 if (msl_options.invariant_float_math || has_decoration(ops[1], DecorationNoContraction))
7701 MSL_BFOP(spvFAdd);
7702 else
7703 MSL_BOP(+);
7704 break;
7705
7706 case OpFSub:
7707 if (msl_options.invariant_float_math || has_decoration(ops[1], DecorationNoContraction))
7708 MSL_BFOP(spvFSub);
7709 else
7710 MSL_BOP(-);
7711 break;
7712
7713 // Atomics
7714 case OpAtomicExchange:
7715 {
7716 uint32_t result_type = ops[0];
7717 uint32_t id = ops[1];
7718 uint32_t ptr = ops[2];
7719 uint32_t mem_sem = ops[4];
7720 uint32_t val = ops[5];
7721 emit_atomic_func_op(result_type, id, "atomic_exchange_explicit", mem_sem, mem_sem, false, ptr, val);
7722 break;
7723 }
7724
7725 case OpAtomicCompareExchange:
7726 {
7727 uint32_t result_type = ops[0];
7728 uint32_t id = ops[1];
7729 uint32_t ptr = ops[2];
7730 uint32_t mem_sem_pass = ops[4];
7731 uint32_t mem_sem_fail = ops[5];
7732 uint32_t val = ops[6];
7733 uint32_t comp = ops[7];
7734 emit_atomic_func_op(result_type, id, "atomic_compare_exchange_weak_explicit", mem_sem_pass, mem_sem_fail, true,
7735 ptr, comp, true, false, val);
7736 break;
7737 }
7738
7739 case OpAtomicCompareExchangeWeak:
7740 SPIRV_CROSS_THROW("OpAtomicCompareExchangeWeak is only supported in kernel profile.");
7741
7742 case OpAtomicLoad:
7743 {
7744 uint32_t result_type = ops[0];
7745 uint32_t id = ops[1];
7746 uint32_t ptr = ops[2];
7747 uint32_t mem_sem = ops[4];
7748 emit_atomic_func_op(result_type, id, "atomic_load_explicit", mem_sem, mem_sem, false, ptr, 0);
7749 break;
7750 }
7751
7752 case OpAtomicStore:
7753 {
7754 uint32_t result_type = expression_type(ops[0]).self;
7755 uint32_t id = ops[0];
7756 uint32_t ptr = ops[0];
7757 uint32_t mem_sem = ops[2];
7758 uint32_t val = ops[3];
7759 emit_atomic_func_op(result_type, id, "atomic_store_explicit", mem_sem, mem_sem, false, ptr, val);
7760 break;
7761 }
7762
7763 #define MSL_AFMO_IMPL(op, valsrc, valconst) \
7764 do \
7765 { \
7766 uint32_t result_type = ops[0]; \
7767 uint32_t id = ops[1]; \
7768 uint32_t ptr = ops[2]; \
7769 uint32_t mem_sem = ops[4]; \
7770 uint32_t val = valsrc; \
7771 emit_atomic_func_op(result_type, id, "atomic_fetch_" #op "_explicit", mem_sem, mem_sem, false, ptr, val, \
7772 false, valconst); \
7773 } while (false)
7774
7775 #define MSL_AFMO(op) MSL_AFMO_IMPL(op, ops[5], false)
7776 #define MSL_AFMIO(op) MSL_AFMO_IMPL(op, 1, true)
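	// For example (a sketch of the lowering): OpAtomicIAdd on a device pointer
	// becomes roughly
	//
	//   atomic_fetch_add_explicit((device atomic_uint*)&..., val, memory_order_relaxed);
	//
	// and OpAtomicIIncrement takes the same fetch-add path with a constant operand of 1.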
7777
7778 case OpAtomicIIncrement:
7779 MSL_AFMIO(add);
7780 break;
7781
7782 case OpAtomicIDecrement:
7783 MSL_AFMIO(sub);
7784 break;
7785
7786 case OpAtomicIAdd:
7787 MSL_AFMO(add);
7788 break;
7789
7790 case OpAtomicISub:
7791 MSL_AFMO(sub);
7792 break;
7793
7794 case OpAtomicSMin:
7795 case OpAtomicUMin:
7796 MSL_AFMO(min);
7797 break;
7798
7799 case OpAtomicSMax:
7800 case OpAtomicUMax:
7801 MSL_AFMO(max);
7802 break;
7803
7804 case OpAtomicAnd:
7805 MSL_AFMO(and);
7806 break;
7807
7808 case OpAtomicOr:
7809 MSL_AFMO(or);
7810 break;
7811
7812 case OpAtomicXor:
7813 MSL_AFMO(xor);
7814 break;
7815
7816 // Images
7817
7818 // Reads == Fetches in Metal
7819 case OpImageRead:
7820 {
7821 // Mark that this shader reads from this image
7822 uint32_t img_id = ops[2];
7823 auto &type = expression_type(img_id);
7824 if (type.image.dim != DimSubpassData)
7825 {
7826 auto *p_var = maybe_get_backing_variable(img_id);
7827 if (p_var && has_decoration(p_var->self, DecorationNonReadable))
7828 {
7829 unset_decoration(p_var->self, DecorationNonReadable);
7830 force_recompile();
7831 }
7832 }
7833
7834 emit_texture_op(instruction, false);
7835 break;
7836 }
7837
7838 // Emulate texture2D atomic operations
7839 case OpImageTexelPointer:
7840 {
7841 // When using the pointer, we need to know which variable it is actually loaded from.
7842 auto *var = maybe_get_backing_variable(ops[2]);
7843 if (var && atomic_image_vars.count(var->self))
7844 {
7845 uint32_t result_type = ops[0];
7846 uint32_t id = ops[1];
7847
7848 std::string coord = to_expression(ops[3]);
7849 auto &type = expression_type(ops[2]);
7850 if (type.image.dim == Dim2D)
7851 {
7852 coord = join("spvImage2DAtomicCoord(", coord, ", ", to_expression(ops[2]), ")");
7853 }
7854
7855 auto &e = set<SPIRExpression>(id, join(to_expression(ops[2]), "_atomic[", coord, "]"), result_type, true);
7856 e.loaded_from = var ? var->self : ID(0);
7857 inherit_expression_dependencies(id, ops[3]);
7858 }
7859 else
7860 {
7861 uint32_t result_type = ops[0];
7862 uint32_t id = ops[1];
7863 auto &e =
7864 set<SPIRExpression>(id, join(to_expression(ops[2]), ", ", to_expression(ops[3])), result_type, true);
7865
7866 // When using the pointer, we need to know which variable it is actually loaded from.
7867 e.loaded_from = var ? var->self : ID(0);
7868 inherit_expression_dependencies(id, ops[3]);
7869 }
7870 break;
7871 }
7872
7873 case OpImageWrite:
7874 {
7875 uint32_t img_id = ops[0];
7876 uint32_t coord_id = ops[1];
7877 uint32_t texel_id = ops[2];
7878 const uint32_t *opt = &ops[3];
7879 uint32_t length = instruction.length - 3;
7880
7881 // Bypass pointers because we need the real image struct
7882 auto &type = expression_type(img_id);
7883 auto &img_type = get<SPIRType>(type.self);
7884
7885 		// Ensure this image has been marked as being written to and force a
7886 		// recompile so that the image type output will include write access.
7887 auto *p_var = maybe_get_backing_variable(img_id);
7888 if (p_var && has_decoration(p_var->self, DecorationNonWritable))
7889 {
7890 unset_decoration(p_var->self, DecorationNonWritable);
7891 force_recompile();
7892 }
7893
7894 bool forward = false;
7895 uint32_t bias = 0;
7896 uint32_t lod = 0;
7897 uint32_t flags = 0;
7898
7899 if (length)
7900 {
7901 flags = *opt++;
7902 length--;
7903 }
7904
7905 auto test = [&](uint32_t &v, uint32_t flag) {
7906 if (length && (flags & flag))
7907 {
7908 v = *opt++;
7909 length--;
7910 }
7911 };
7912
7913 test(bias, ImageOperandsBiasMask);
7914 test(lod, ImageOperandsLodMask);
7915
7916 auto &texel_type = expression_type(texel_id);
7917 auto store_type = texel_type;
7918 store_type.vecsize = 4;
7919
7920 TextureFunctionArguments args = {};
7921 args.base.img = img_id;
7922 args.base.imgtype = &img_type;
7923 args.base.is_fetch = true;
7924 args.coord = coord_id;
7925 args.lod = lod;
7926 statement(join(to_expression(img_id), ".write(",
7927 remap_swizzle(store_type, texel_type.vecsize, to_expression(texel_id)), ", ",
7928 CompilerMSL::to_function_args(args, &forward), ");"));
7929
7930 if (p_var && variable_storage_is_aliased(*p_var))
7931 flush_all_aliased_variables();
7932
7933 break;
7934 }
7935
7936 case OpImageQuerySize:
7937 case OpImageQuerySizeLod:
7938 {
7939 uint32_t rslt_type_id = ops[0];
7940 auto &rslt_type = get<SPIRType>(rslt_type_id);
7941
7942 uint32_t id = ops[1];
7943
7944 uint32_t img_id = ops[2];
7945 string img_exp = to_expression(img_id);
7946 auto &img_type = expression_type(img_id);
7947 Dim img_dim = img_type.image.dim;
7948 bool img_is_array = img_type.image.arrayed;
7949
7950 if (img_type.basetype != SPIRType::Image)
7951 SPIRV_CROSS_THROW("Invalid type for OpImageQuerySize.");
7952
7953 string lod;
7954 if (opcode == OpImageQuerySizeLod)
7955 {
7956 			// LOD index defaults to zero, so don't bother outputting the level zero index.
7957 string decl_lod = to_expression(ops[3]);
7958 if (decl_lod != "0")
7959 lod = decl_lod;
7960 }
7961
7962 string expr = type_to_glsl(rslt_type) + "(";
7963 expr += img_exp + ".get_width(" + lod + ")";
7964
7965 if (img_dim == Dim2D || img_dim == DimCube || img_dim == Dim3D)
7966 expr += ", " + img_exp + ".get_height(" + lod + ")";
7967
7968 if (img_dim == Dim3D)
7969 expr += ", " + img_exp + ".get_depth(" + lod + ")";
7970
7971 if (img_is_array)
7972 {
7973 expr += ", " + img_exp + ".get_array_size()";
7974 if (img_dim == DimCube && msl_options.emulate_cube_array)
7975 expr += " / 6";
7976 }
7977
7978 expr += ")";
7979
7980 emit_op(rslt_type_id, id, expr, should_forward(img_id));
7981
7982 break;
7983 }
7984
7985 case OpImageQueryLod:
7986 {
7987 if (!msl_options.supports_msl_version(2, 2))
7988 SPIRV_CROSS_THROW("ImageQueryLod is only supported on MSL 2.2 and up.");
7989 uint32_t result_type = ops[0];
7990 uint32_t id = ops[1];
7991 uint32_t image_id = ops[2];
7992 uint32_t coord_id = ops[3];
7993 emit_uninitialized_temporary_expression(result_type, id);
7994
7995 auto sampler_expr = to_sampler_expression(image_id);
7996 auto *combined = maybe_get<SPIRCombinedImageSampler>(image_id);
7997 auto image_expr = combined ? to_expression(combined->image) : to_expression(image_id);
7998
7999 		// TODO: It is unclear if calculate_clamped_lod also conditionally rounds
8000 // the reported LOD based on the sampler. NEAREST miplevel should
8001 // round the LOD, but LINEAR miplevel should not round.
8002 // Let's hope this does not become an issue ...
8003 statement(to_expression(id), ".x = ", image_expr, ".calculate_clamped_lod(", sampler_expr, ", ",
8004 to_expression(coord_id), ");");
8005 statement(to_expression(id), ".y = ", image_expr, ".calculate_unclamped_lod(", sampler_expr, ", ",
8006 to_expression(coord_id), ");");
8007 register_control_dependent_expression(id);
8008 break;
8009 }
8010
8011 #define MSL_ImgQry(qrytype) \
8012 do \
8013 { \
8014 uint32_t rslt_type_id = ops[0]; \
8015 auto &rslt_type = get<SPIRType>(rslt_type_id); \
8016 uint32_t id = ops[1]; \
8017 uint32_t img_id = ops[2]; \
8018 string img_exp = to_expression(img_id); \
8019 string expr = type_to_glsl(rslt_type) + "(" + img_exp + ".get_num_" #qrytype "())"; \
8020 emit_op(rslt_type_id, id, expr, should_forward(img_id)); \
8021 } while (false)
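	// e.g. OpImageQueryLevels on a texture "tex" (name assumed) with a uint result
	// type expands to
	//
	//   uint(tex.get_num_mip_levels())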
8022
8023 case OpImageQueryLevels:
8024 MSL_ImgQry(mip_levels);
8025 break;
8026
8027 case OpImageQuerySamples:
8028 MSL_ImgQry(samples);
8029 break;
8030
8031 case OpImage:
8032 {
8033 uint32_t result_type = ops[0];
8034 uint32_t id = ops[1];
8035 auto *combined = maybe_get<SPIRCombinedImageSampler>(ops[2]);
8036
8037 if (combined)
8038 {
8039 auto &e = emit_op(result_type, id, to_expression(combined->image), true, true);
8040 auto *var = maybe_get_backing_variable(combined->image);
8041 if (var)
8042 e.loaded_from = var->self;
8043 }
8044 else
8045 {
8046 auto *var = maybe_get_backing_variable(ops[2]);
8047 SPIRExpression *e;
8048 if (var && has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler))
8049 e = &emit_op(result_type, id, join(to_expression(ops[2]), ".plane0"), true, true);
8050 else
8051 e = &emit_op(result_type, id, to_expression(ops[2]), true, true);
8052 if (var)
8053 e->loaded_from = var->self;
8054 }
8055 break;
8056 }
8057
8058 // Casting
8059 case OpQuantizeToF16:
8060 {
8061 uint32_t result_type = ops[0];
8062 uint32_t id = ops[1];
8063 uint32_t arg = ops[2];
8064
8065 string exp;
8066 auto &type = get<SPIRType>(result_type);
8067
8068 switch (type.vecsize)
8069 {
8070 case 1:
8071 exp = join("float(half(", to_expression(arg), "))");
8072 break;
8073 case 2:
8074 exp = join("float2(half2(", to_expression(arg), "))");
8075 break;
8076 case 3:
8077 exp = join("float3(half3(", to_expression(arg), "))");
8078 break;
8079 case 4:
8080 exp = join("float4(half4(", to_expression(arg), "))");
8081 break;
8082 default:
8083 SPIRV_CROSS_THROW("Illegal argument to OpQuantizeToF16.");
8084 }
8085
8086 emit_op(result_type, id, exp, should_forward(arg));
8087 break;
8088 }
8089
8090 case OpInBoundsAccessChain:
8091 case OpAccessChain:
8092 case OpPtrAccessChain:
8093 if (is_tessellation_shader())
8094 {
8095 if (!emit_tessellation_access_chain(ops, instruction.length))
8096 CompilerGLSL::emit_instruction(instruction);
8097 }
8098 else
8099 CompilerGLSL::emit_instruction(instruction);
8100 fix_up_interpolant_access_chain(ops, instruction.length);
8101 break;
8102
8103 case OpStore:
8104 if (is_out_of_bounds_tessellation_level(ops[0]))
8105 break;
8106
8107 if (maybe_emit_array_assignment(ops[0], ops[1]))
8108 break;
8109
8110 CompilerGLSL::emit_instruction(instruction);
8111 break;
8112
8113 // Compute barriers
8114 case OpMemoryBarrier:
8115 emit_barrier(0, ops[0], ops[1]);
8116 break;
8117
8118 case OpControlBarrier:
8119 // In GLSL a memory barrier is often followed by a control barrier.
8120 // But in MSL, memory barriers are also control barriers, so don't
8121 // emit a simple control barrier if a memory barrier has just been emitted.
8122 if (previous_instruction_opcode != OpMemoryBarrier)
8123 emit_barrier(ops[0], ops[1], ops[2]);
8124 break;
8125
8126 case OpOuterProduct:
8127 {
8128 uint32_t result_type = ops[0];
8129 uint32_t id = ops[1];
8130 uint32_t a = ops[2];
8131 uint32_t b = ops[3];
8132
8133 auto &type = get<SPIRType>(result_type);
8134 string expr = type_to_glsl_constructor(type);
8135 expr += "(";
8136 for (uint32_t col = 0; col < type.columns; col++)
8137 {
8138 expr += to_enclosed_expression(a);
8139 expr += " * ";
8140 expr += to_extract_component_expression(b, col);
8141 if (col + 1 < type.columns)
8142 expr += ", ";
8143 }
8144 expr += ")";
8145 emit_op(result_type, id, expr, should_forward(a) && should_forward(b));
8146 inherit_expression_dependencies(id, a);
8147 inherit_expression_dependencies(id, b);
8148 break;
8149 }
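	// Sketch of the construction above: for float2 vectors a and b, the outer
	// product is emitted column by column as
	//
	//   float2x2(a * b.x, a * b.y)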
8150
8151 case OpVectorTimesMatrix:
8152 case OpMatrixTimesVector:
8153 {
8154 if (!msl_options.invariant_float_math && !has_decoration(ops[1], DecorationNoContraction))
8155 {
8156 CompilerGLSL::emit_instruction(instruction);
8157 break;
8158 }
8159
8160 // If the matrix needs transpose, just flip the multiply order.
8161 auto *e = maybe_get<SPIRExpression>(ops[opcode == OpMatrixTimesVector ? 2 : 3]);
8162 if (e && e->need_transpose)
8163 {
8164 e->need_transpose = false;
8165 string expr;
8166
8167 if (opcode == OpMatrixTimesVector)
8168 {
8169 expr = join("spvFMulVectorMatrix(", to_enclosed_unpacked_expression(ops[3]), ", ",
8170 to_unpacked_row_major_matrix_expression(ops[2]), ")");
8171 }
8172 else
8173 {
8174 expr = join("spvFMulMatrixVector(", to_unpacked_row_major_matrix_expression(ops[3]), ", ",
8175 to_enclosed_unpacked_expression(ops[2]), ")");
8176 }
8177
8178 bool forward = should_forward(ops[2]) && should_forward(ops[3]);
8179 emit_op(ops[0], ops[1], expr, forward);
8180 e->need_transpose = true;
8181 inherit_expression_dependencies(ops[1], ops[2]);
8182 inherit_expression_dependencies(ops[1], ops[3]);
8183 }
8184 else
8185 {
8186 if (opcode == OpMatrixTimesVector)
8187 MSL_BFOP(spvFMulMatrixVector);
8188 else
8189 MSL_BFOP(spvFMulVectorMatrix);
8190 }
8191 break;
8192 }
8193
8194 case OpMatrixTimesMatrix:
8195 {
8196 if (!msl_options.invariant_float_math && !has_decoration(ops[1], DecorationNoContraction))
8197 {
8198 CompilerGLSL::emit_instruction(instruction);
8199 break;
8200 }
8201
8202 auto *a = maybe_get<SPIRExpression>(ops[2]);
8203 auto *b = maybe_get<SPIRExpression>(ops[3]);
8204
8205 // If both matrices need transpose, we can multiply in flipped order and tag the expression as transposed.
8206 // a^T * b^T = (b * a)^T.
8207 if (a && b && a->need_transpose && b->need_transpose)
8208 {
8209 a->need_transpose = false;
8210 b->need_transpose = false;
8211
8212 auto expr =
8213 join("spvFMulMatrixMatrix(", enclose_expression(to_unpacked_row_major_matrix_expression(ops[3])), ", ",
8214 enclose_expression(to_unpacked_row_major_matrix_expression(ops[2])), ")");
8215
8216 bool forward = should_forward(ops[2]) && should_forward(ops[3]);
8217 auto &e = emit_op(ops[0], ops[1], expr, forward);
8218 e.need_transpose = true;
8219 a->need_transpose = true;
8220 b->need_transpose = true;
8221 inherit_expression_dependencies(ops[1], ops[2]);
8222 inherit_expression_dependencies(ops[1], ops[3]);
8223 }
8224 else
8225 MSL_BFOP(spvFMulMatrixMatrix);
8226
8227 break;
8228 }
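	// Worked form of the identity used above (a sketch): given transposed inputs,
	// we multiply in flipped order and mark the result as transposed, since
	//
	//   (B * A)^T == A^T * B^T
	//
	// so spvFMulMatrixMatrix(B, A) plus the need_transpose tag reproduces
	// A^T * B^T without materializing either transpose.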
8229
8230 case OpIAddCarry:
8231 case OpISubBorrow:
8232 {
8233 uint32_t result_type = ops[0];
8234 uint32_t result_id = ops[1];
8235 uint32_t op0 = ops[2];
8236 uint32_t op1 = ops[3];
8237 auto &type = get<SPIRType>(result_type);
8238 emit_uninitialized_temporary_expression(result_type, result_id);
8239
8240 auto &res_type = get<SPIRType>(type.member_types[1]);
8241 if (opcode == OpIAddCarry)
8242 {
8243 statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " + ",
8244 to_enclosed_expression(op1), ";");
8245 statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
8246 "(1), ", type_to_glsl(res_type), "(0), ", to_expression(result_id), ".", to_member_name(type, 0),
8247 " >= max(", to_expression(op0), ", ", to_expression(op1), "));");
8248 }
8249 else
8250 {
8251 statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " - ",
8252 to_enclosed_expression(op1), ";");
8253 statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
8254 "(1), ", type_to_glsl(res_type), "(0), ", to_enclosed_expression(op0),
8255 " >= ", to_enclosed_expression(op1), ");");
8256 }
8257 break;
8258 }
8259
8260 case OpUMulExtended:
8261 case OpSMulExtended:
8262 {
8263 uint32_t result_type = ops[0];
8264 uint32_t result_id = ops[1];
8265 uint32_t op0 = ops[2];
8266 uint32_t op1 = ops[3];
8267 auto &type = get<SPIRType>(result_type);
8268 emit_uninitialized_temporary_expression(result_type, result_id);
8269
8270 statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_expression(op0), " * ",
8271 to_enclosed_expression(op1), ";");
8272 statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(", to_expression(op0), ", ",
8273 to_expression(op1), ");");
8274 break;
8275 }
8276
8277 case OpArrayLength:
8278 {
8279 auto &type = expression_type(ops[2]);
8280 uint32_t offset = type_struct_member_offset(type, ops[3]);
8281 uint32_t stride = type_struct_member_array_stride(type, ops[3]);
8282
8283 auto expr = join("(", to_buffer_size_expression(ops[2]), " - ", offset, ") / ", stride);
8284 emit_op(ops[0], ops[1], expr, true);
8285 break;
8286 }
8287
8288 // SPV_INTEL_shader_integer_functions2
8289 case OpUCountLeadingZerosINTEL:
8290 MSL_UFOP(clz);
8291 break;
8292
8293 case OpUCountTrailingZerosINTEL:
8294 MSL_UFOP(ctz);
8295 break;
8296
8297 case OpAbsISubINTEL:
8298 case OpAbsUSubINTEL:
8299 MSL_BFOP(absdiff);
8300 break;
8301
8302 case OpIAddSatINTEL:
8303 case OpUAddSatINTEL:
8304 MSL_BFOP(addsat);
8305 break;
8306
8307 case OpIAverageINTEL:
8308 case OpUAverageINTEL:
8309 MSL_BFOP(hadd);
8310 break;
8311
8312 case OpIAverageRoundedINTEL:
8313 case OpUAverageRoundedINTEL:
8314 MSL_BFOP(rhadd);
8315 break;
8316
8317 case OpISubSatINTEL:
8318 case OpUSubSatINTEL:
8319 MSL_BFOP(subsat);
8320 break;
8321
8322 case OpIMul32x16INTEL:
8323 {
8324 uint32_t result_type = ops[0];
8325 uint32_t id = ops[1];
8326 uint32_t a = ops[2], b = ops[3];
8327 bool forward = should_forward(a) && should_forward(b);
8328 emit_op(result_type, id, join("int(short(", to_expression(a), ")) * int(short(", to_expression(b), "))"),
8329 forward);
8330 inherit_expression_dependencies(id, a);
8331 inherit_expression_dependencies(id, b);
8332 break;
8333 }
8334
8335 case OpUMul32x16INTEL:
8336 {
8337 uint32_t result_type = ops[0];
8338 uint32_t id = ops[1];
8339 uint32_t a = ops[2], b = ops[3];
8340 bool forward = should_forward(a) && should_forward(b);
8341 emit_op(result_type, id, join("uint(ushort(", to_expression(a), ")) * uint(ushort(", to_expression(b), "))"),
8342 forward);
8343 inherit_expression_dependencies(id, a);
8344 inherit_expression_dependencies(id, b);
8345 break;
8346 }
8347
8348 // SPV_EXT_demote_to_helper_invocation
8349 case OpDemoteToHelperInvocationEXT:
8350 if (!msl_options.supports_msl_version(2, 3))
8351 SPIRV_CROSS_THROW("discard_fragment() does not formally have demote semantics until MSL 2.3.");
8352 CompilerGLSL::emit_instruction(instruction);
8353 break;
8354
8355 case OpIsHelperInvocationEXT:
8356 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
8357 SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.3 on iOS.");
8358 else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
8359 SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.1 on macOS.");
8360 emit_op(ops[0], ops[1], "simd_is_helper_thread()", false);
8361 break;
8362
8363 case OpBeginInvocationInterlockEXT:
8364 case OpEndInvocationInterlockEXT:
8365 if (!msl_options.supports_msl_version(2, 0))
8366 SPIRV_CROSS_THROW("Raster order groups require MSL 2.0.");
8367 break; // Nothing to do in the body
8368
8369 default:
8370 CompilerGLSL::emit_instruction(instruction);
8371 break;
8372 }
8373
8374 previous_instruction_opcode = opcode;
8375 }
8376
8377 void CompilerMSL::emit_texture_op(const Instruction &i, bool sparse)
8378 {
8379 if (sparse)
8380 SPIRV_CROSS_THROW("Sparse feedback not yet supported in MSL.");
8381
8382 if (msl_options.use_framebuffer_fetch_subpasses)
8383 {
8384 auto *ops = stream(i);
8385
8386 uint32_t result_type_id = ops[0];
8387 uint32_t id = ops[1];
8388 uint32_t img = ops[2];
8389
8390 auto &type = expression_type(img);
8391 auto &imgtype = get<SPIRType>(type.self);
8392
8393 // Use Metal's native frame-buffer fetch API for subpass inputs.
8394 if (imgtype.image.dim == DimSubpassData)
8395 {
8396 // Subpass inputs cannot be invalidated,
8397 // so just forward the expression directly.
8398 string expr = to_expression(img);
8399 emit_op(result_type_id, id, expr, true);
8400 return;
8401 }
8402 }
8403
8404 // Fallback to default implementation
8405 CompilerGLSL::emit_texture_op(i, sparse);
8406 }
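// e.g. with framebuffer fetch, an OpImageRead from a subpass input collapses to
// the input variable itself (sketch, name assumed):
//
//   float4 v = subpass0; // subpass0 declared as a [[color(N)]] fragment input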
8407
8408 void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uint32_t id_mem_sem)
8409 {
8410 if (get_execution_model() != ExecutionModelGLCompute && get_execution_model() != ExecutionModelTessellationControl)
8411 return;
8412
8413 uint32_t exe_scope = id_exe_scope ? evaluate_constant_u32(id_exe_scope) : uint32_t(ScopeInvocation);
8414 uint32_t mem_scope = id_mem_scope ? evaluate_constant_u32(id_mem_scope) : uint32_t(ScopeInvocation);
8415 // Use the wider of the two scopes (smaller value)
8416 exe_scope = min(exe_scope, mem_scope);
8417
8418 if (msl_options.emulate_subgroups && exe_scope >= ScopeSubgroup && !id_mem_sem)
8419 // In this case, we assume a "subgroup" size of 1. The barrier, then, is a noop.
8420 return;
8421
8422 string bar_stmt;
8423 if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2))
8424 bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
8425 else
8426 bar_stmt = "threadgroup_barrier";
8427 bar_stmt += "(";
8428
8429 uint32_t mem_sem = id_mem_sem ? evaluate_constant_u32(id_mem_sem) : uint32_t(MemorySemanticsMaskNone);
8430
8431 // Use the | operator to combine flags if we can.
8432 if (msl_options.supports_msl_version(1, 2))
8433 {
8434 string mem_flags = "";
8435 // For tesc shaders, this also affects objects in the Output storage class.
8436 // Since in Metal, these are placed in a device buffer, we have to sync device memory here.
8437 if (get_execution_model() == ExecutionModelTessellationControl ||
8438 (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)))
8439 mem_flags += "mem_flags::mem_device";
8440
8441 // Fix tessellation patch function processing
8442 if (get_execution_model() == ExecutionModelTessellationControl ||
8443 (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
8444 {
8445 if (!mem_flags.empty())
8446 mem_flags += " | ";
8447 mem_flags += "mem_flags::mem_threadgroup";
8448 }
8449 if (mem_sem & MemorySemanticsImageMemoryMask)
8450 {
8451 if (!mem_flags.empty())
8452 mem_flags += " | ";
8453 mem_flags += "mem_flags::mem_texture";
8454 }
8455
8456 if (mem_flags.empty())
8457 mem_flags = "mem_flags::mem_none";
8458
8459 bar_stmt += mem_flags;
8460 }
8461 else
8462 {
8463 if ((mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)) &&
8464 (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
8465 bar_stmt += "mem_flags::mem_device_and_threadgroup";
8466 else if (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask))
8467 bar_stmt += "mem_flags::mem_device";
8468 else if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask))
8469 bar_stmt += "mem_flags::mem_threadgroup";
8470 else if (mem_sem & MemorySemanticsImageMemoryMask)
8471 bar_stmt += "mem_flags::mem_texture";
8472 else
8473 bar_stmt += "mem_flags::mem_none";
8474 }
8475
8476 bar_stmt += ");";
8477
8478 statement(bar_stmt);
8479
8480 assert(current_emitting_block);
8481 flush_control_dependent_expressions(current_emitting_block->self);
8482 flush_all_active_variables();
8483 }
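// A minimal sketch of the output, assuming a compute shader and MSL 1.2+: a GLSL
// barrier(), i.e. OpControlBarrier with Workgroup scope and WorkgroupMemory semantics,
// is emitted as
//     threadgroup_barrier(mem_flags::mem_threadgroup);
// while a subgroup-scoped barrier on a capable target uses simdgroup_barrier() instead.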
8484
8485 static bool storage_class_array_is_thread(StorageClass storage)
8486 {
8487 switch (storage)
8488 {
8489 case StorageClassInput:
8490 case StorageClassOutput:
8491 case StorageClassGeneric:
8492 case StorageClassFunction:
8493 case StorageClassPrivate:
8494 return true;
8495
8496 default:
8497 return false;
8498 }
8499 }
8500
8501 void CompilerMSL::emit_array_copy(const string &lhs, uint32_t lhs_id, uint32_t rhs_id,
8502 StorageClass lhs_storage, StorageClass rhs_storage)
8503 {
8504 // Allow Metal to use the array<T> template to make arrays a value type.
8505 // This, however, cannot be used for the threadgroup address space, so fall back to the custom array copy functions there.
8506 bool lhs_is_thread_storage = storage_class_array_is_thread(lhs_storage);
8507 bool rhs_is_thread_storage = storage_class_array_is_thread(rhs_storage);
8508
8509 bool lhs_is_array_template = lhs_is_thread_storage;
8510 bool rhs_is_array_template = rhs_is_thread_storage;
8511
8512 // Special considerations for stage IO variables.
8513 // If the variable is actually backed by non-user visible device storage, we use array templates for those.
8514 //
8515 // Another special consideration is given to thread local variables which happen to have Offset decorations
8516 // applied to them. Block-like types do not use array templates, so we need to force POD path if we detect
8517 // these scenarios. This check isn't perfect since it would be technically possible to mix and match these things,
8518 // and for a fully correct solution we might have to track array template state through access chains as well,
8519 // but for all reasonable use cases, this should suffice.
8520 // This special case should also only apply to Function/Private storage classes.
8521 // We should not check backing variable for temporaries.
8522 auto *lhs_var = maybe_get_backing_variable(lhs_id);
8523 if (lhs_var && lhs_storage == StorageClassStorageBuffer && storage_class_array_is_thread(lhs_var->storage))
8524 lhs_is_array_template = true;
8525 else if (lhs_var && (lhs_storage == StorageClassFunction || lhs_storage == StorageClassPrivate) &&
8526 type_is_block_like(get<SPIRType>(lhs_var->basetype)))
8527 lhs_is_array_template = false;
8528
8529 auto *rhs_var = maybe_get_backing_variable(rhs_id);
8530 if (rhs_var && rhs_storage == StorageClassStorageBuffer && storage_class_array_is_thread(rhs_var->storage))
8531 rhs_is_array_template = true;
8532 else if (rhs_var && (rhs_storage == StorageClassFunction || rhs_storage == StorageClassPrivate) &&
8533 type_is_block_like(get<SPIRType>(rhs_var->basetype)))
8534 rhs_is_array_template = false;
8535
8536 // If threadgroup storage qualifiers are *not* used:
8537 // avoid the spvArrayCopy* wrapper functions; otherwise, the spvUnsafeArray<> template cannot be used with that storage qualifier.
8538 if (lhs_is_array_template && rhs_is_array_template && !using_builtin_array())
8539 {
8540 statement(lhs, " = ", to_expression(rhs_id), ";");
8541 }
8542 else
8543 {
8544 // Assignment from an array initializer is fine.
8545 auto &type = expression_type(rhs_id);
8546 auto *var = maybe_get_backing_variable(rhs_id);
8547
8548 // Unfortunately, we cannot template on address space in MSL,
8549 // so explicit address space redirection it is ...
8550 bool is_constant = false;
8551 if (ir.ids[rhs_id].get_type() == TypeConstant)
8552 {
8553 is_constant = true;
8554 }
8555 else if (var && var->remapped_variable && var->statically_assigned &&
8556 ir.ids[var->static_expression].get_type() == TypeConstant)
8557 {
8558 is_constant = true;
8559 }
8560 else if (rhs_storage == StorageClassUniform)
8561 {
8562 is_constant = true;
8563 }
8564
8565 // For the case where we have OpLoad triggering an array copy,
8566 // we cannot easily detect this case ahead of time since it's
8567 // context dependent. We might have to force a recompile here
8568 // if this is the only use of array copies in our shader.
8569 if (type.array.size() > 1)
8570 {
8571 if (type.array.size() > kArrayCopyMultidimMax)
8572 SPIRV_CROSS_THROW("Cannot support this many dimensions for arrays of arrays.");
8573 auto func = static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + type.array.size());
8574 add_spv_func_and_recompile(func);
8575 }
8576 else
8577 add_spv_func_and_recompile(SPVFuncImplArrayCopy);
8578
8579 const char *tag = nullptr;
8580 if (lhs_is_thread_storage && is_constant)
8581 tag = "FromConstantToStack";
8582 else if (lhs_storage == StorageClassWorkgroup && is_constant)
8583 tag = "FromConstantToThreadGroup";
8584 else if (lhs_is_thread_storage && rhs_is_thread_storage)
8585 tag = "FromStackToStack";
8586 else if (lhs_storage == StorageClassWorkgroup && rhs_is_thread_storage)
8587 tag = "FromStackToThreadGroup";
8588 else if (lhs_is_thread_storage && rhs_storage == StorageClassWorkgroup)
8589 tag = "FromThreadGroupToStack";
8590 else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassWorkgroup)
8591 tag = "FromThreadGroupToThreadGroup";
8592 else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassStorageBuffer)
8593 tag = "FromDeviceToDevice";
8594 else if (lhs_storage == StorageClassStorageBuffer && is_constant)
8595 tag = "FromConstantToDevice";
8596 else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassWorkgroup)
8597 tag = "FromThreadGroupToDevice";
8598 else if (lhs_storage == StorageClassStorageBuffer && rhs_is_thread_storage)
8599 tag = "FromStackToDevice";
8600 else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassStorageBuffer)
8601 tag = "FromDeviceToThreadGroup";
8602 else if (lhs_is_thread_storage && rhs_storage == StorageClassStorageBuffer)
8603 tag = "FromDeviceToStack";
8604 else
8605 SPIRV_CROSS_THROW("Unknown storage class used for copying arrays.");
8606
8607 // Pass internal array of spvUnsafeArray<> into wrapper functions
8608 if (lhs_is_array_template && rhs_is_array_template && !msl_options.force_native_arrays)
8609 statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ".elements, ", to_expression(rhs_id), ".elements);");
8610 else if (lhs_is_array_template && !msl_options.force_native_arrays)
8611 statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ".elements, ", to_expression(rhs_id), ");");
8612 else if (rhs_is_array_template && !msl_options.force_native_arrays)
8613 statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ", ", to_expression(rhs_id), ".elements);");
8614 else
8615 statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ", ", to_expression(rhs_id), ");");
8616 }
8617 }
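// Sketch of the two shapes this emits (hypothetical identifiers). Template path:
//     dst = src;
// Wrapper path, e.g. copying a constant array into a stack-allocated spvUnsafeArray<>:
//     spvArrayCopyFromConstantToStack1(dst.elements, src);
// where the trailing 1 is the number of array dimensions.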
8618
8619 uint32_t CompilerMSL::get_physical_tess_level_array_size(spv::BuiltIn builtin) const
8620 {
8621 if (get_execution_mode_bitset().get(ExecutionModeTriangles))
8622 return builtin == BuiltInTessLevelInner ? 1 : 3;
8623 else
8624 return builtin == BuiltInTessLevelInner ? 2 : 4;
8625 }
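// These sizes mirror Metal's tessellation factor structures: MTLTriangleTessellationFactorsHalf
// holds 3 edge factors and 1 inside factor, while MTLQuadTessellationFactorsHalf holds
// 4 edge factors and 2 inside factors.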
8626
8627 // Since MSL does not allow arrays to be copied via simple variable assignment,
8628 // if the LHS and RHS represent an assignment of an entire array, it must be
8629 // implemented by calling an array copy function.
8630 // Returns whether the array assignment was emitted.
8631 bool CompilerMSL::maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs)
8632 {
8633 // We only care about assignments of an entire array
8634 auto &type = expression_type(id_rhs);
8635 if (type.array.size() == 0)
8636 return false;
8637
8638 auto *var = maybe_get<SPIRVariable>(id_lhs);
8639
8640 // Is this a remapped, static constant? Don't do anything.
8641 if (var && var->remapped_variable && var->statically_assigned)
8642 return true;
8643
8644 if (ir.ids[id_rhs].get_type() == TypeConstant && var && var->deferred_declaration)
8645 {
8646 // Special case, if we end up declaring a variable when assigning the constant array,
8647 // we can avoid the copy by directly assigning the constant expression.
8648 // This is likely necessary to be able to use a variable as a true look-up table, as it is unlikely
8649 // the compiler will be able to optimize the spvArrayCopy() into a constant LUT.
8650 // After a variable has been declared, we can no longer assign constant arrays in MSL unfortunately.
8651 statement(to_expression(id_lhs), " = ", constant_expression(get<SPIRConstant>(id_rhs)), ";");
8652 return true;
8653 }
8654
8655 if (get_execution_model() == ExecutionModelTessellationControl &&
8656 has_decoration(id_lhs, DecorationBuiltIn))
8657 {
8658 auto builtin = BuiltIn(get_decoration(id_lhs, DecorationBuiltIn));
8659 // Need to manually unroll the array store.
8660 if (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter)
8661 {
8662 uint32_t array_size = get_physical_tess_level_array_size(builtin);
8663 if (array_size == 1)
8664 statement(to_expression(id_lhs), " = half(", to_expression(id_rhs), "[0]);");
8665 else
8666 {
8667 for (uint32_t i = 0; i < array_size; i++)
8668 statement(to_expression(id_lhs), "[", i, "] = half(", to_expression(id_rhs), "[", i, "]);");
8669 }
8670 return true;
8671 }
8672 }
8673
8674 // Ensure the LHS variable has been declared
8675 auto *p_v_lhs = maybe_get_backing_variable(id_lhs);
8676 if (p_v_lhs)
8677 flush_variable_declaration(p_v_lhs->self);
8678
8679 auto lhs_storage = get_expression_effective_storage_class(id_lhs);
8680 auto rhs_storage = get_expression_effective_storage_class(id_rhs);
8681 emit_array_copy(to_expression(id_lhs), id_lhs, id_rhs, lhs_storage, rhs_storage);
8682 register_write(id_lhs);
8683
8684 return true;
8685 }
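// For the unrolled tess-level path above, the emitted MSL is one scalar store per
// factor, each converted to half, e.g. (hypothetical names):
//     tessFactors[0] = half(gl_TessLevelOuter[0]);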
8686
8687 // Emits one of the atomic functions. In MSL, the atomic functions operate on pointers to atomic types, so the operand is cast to a pointer in the matching address space.
8688 void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, uint32_t mem_order_1,
8689 uint32_t mem_order_2, bool has_mem_order_2, uint32_t obj, uint32_t op1,
8690 bool op1_is_pointer, bool op1_is_literal, uint32_t op2)
8691 {
8692 string exp = string(op) + "(";
8693
8694 auto &type = get_pointee_type(expression_type(obj));
8695 exp += "(";
8696 auto *var = maybe_get_backing_variable(obj);
8697 if (!var)
8698 SPIRV_CROSS_THROW("No backing variable for atomic operation.");
8699
8700 // Emulate texture2D atomic operations
8701 const auto &res_type = get<SPIRType>(var->basetype);
8702 if (res_type.storage == StorageClassUniformConstant && res_type.basetype == SPIRType::Image)
8703 {
8704 exp += "device";
8705 }
8706 else
8707 {
8708 exp += get_argument_address_space(*var);
8709 }
8710
8711 exp += " atomic_";
8712 exp += type_to_glsl(type);
8713 exp += "*)";
8714
8715 exp += "&";
8716 exp += to_enclosed_expression(obj);
8717
8718 bool is_atomic_compare_exchange_strong = op1_is_pointer && op1;
8719
8720 if (is_atomic_compare_exchange_strong)
8721 {
8722 assert(strcmp(op, "atomic_compare_exchange_weak_explicit") == 0);
8723 assert(op2);
8724 assert(has_mem_order_2);
8725 exp += ", &";
8726 exp += to_name(result_id);
8727 exp += ", ";
8728 exp += to_expression(op2);
8729 exp += ", ";
8730 exp += get_memory_order(mem_order_1);
8731 exp += ", ";
8732 exp += get_memory_order(mem_order_2);
8733 exp += ")";
8734
8735 // MSL only supports the weak atomic compare exchange, so emit a CAS loop here.
8736 // The MSL function returns false if the atomic write fails OR the comparison test fails,
8737 // so we must validate that it wasn't the comparison test that failed before continuing
8738 // the CAS loop, otherwise it will loop infinitely, with the comparison test always failing.
8739 // The function updates the comparator value from the memory value, so the additional
8740 // comparison test evaluates the memory value against the expected value.
8741 emit_uninitialized_temporary_expression(result_type, result_id);
8742 statement("do");
8743 begin_scope();
8744 statement(to_name(result_id), " = ", to_expression(op1), ";");
8745 end_scope_decl(join("while (!", exp, " && ", to_name(result_id), " == ", to_enclosed_expression(op1), ")"));
8746 }
8747 else
8748 {
8749 assert(strcmp(op, "atomic_compare_exchange_weak_explicit") != 0);
8750 if (op1)
8751 {
8752 if (op1_is_literal)
8753 exp += join(", ", op1);
8754 else
8755 exp += ", " + to_expression(op1);
8756 }
8757 if (op2)
8758 exp += ", " + to_expression(op2);
8759
8760 exp += string(", ") + get_memory_order(mem_order_1);
8761 if (has_mem_order_2)
8762 exp += string(", ") + get_memory_order(mem_order_2);
8763
8764 exp += ")";
8765
8766 if (strcmp(op, "atomic_store_explicit") != 0)
8767 emit_op(result_type, result_id, exp, false);
8768 else
8769 statement(exp, ";");
8770 }
8771
8772 flush_all_atomic_capable_variables();
8773 }
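// Sketch of the CAS loop emitted for OpAtomicCompareExchange on a device buffer member
// (hypothetical identifiers):
//     uint _40;
//     do
//     {
//         _40 = expected;
//     } while (!atomic_compare_exchange_weak_explicit((device atomic_uint*)&buf.value, &_40,
//              desired, memory_order_relaxed, memory_order_relaxed) && _40 == expected);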
8774
8775 // Metal only supports relaxed memory order for now
8776 const char *CompilerMSL::get_memory_order(uint32_t)
8777 {
8778 return "memory_order_relaxed";
8779 }
8780
8781 // Override for MSL-specific extension syntax instructions
8782 void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
8783 {
8784 auto op = static_cast<GLSLstd450>(eop);
8785
8786 // If we need to do implicit bitcasts, make sure we do it with the correct type.
8787 uint32_t integer_width = get_integer_width_for_glsl_instruction(op, args, count);
8788 auto int_type = to_signed_basetype(integer_width);
8789 auto uint_type = to_unsigned_basetype(integer_width);
8790
8791 switch (op)
8792 {
8793 case GLSLstd450Atan2:
8794 emit_binary_func_op(result_type, id, args[0], args[1], "atan2");
8795 break;
8796 case GLSLstd450InverseSqrt:
8797 emit_unary_func_op(result_type, id, args[0], "rsqrt");
8798 break;
8799 case GLSLstd450RoundEven:
8800 emit_unary_func_op(result_type, id, args[0], "rint");
8801 break;
8802
8803 case GLSLstd450FindILsb:
8804 {
8805 // In this template version of findLSB, we return T.
8806 auto basetype = expression_type(args[0]).basetype;
8807 emit_unary_func_op_cast(result_type, id, args[0], "spvFindLSB", basetype, basetype);
8808 break;
8809 }
8810
8811 case GLSLstd450FindSMsb:
8812 emit_unary_func_op_cast(result_type, id, args[0], "spvFindSMSB", int_type, int_type);
8813 break;
8814
8815 case GLSLstd450FindUMsb:
8816 emit_unary_func_op_cast(result_type, id, args[0], "spvFindUMSB", uint_type, uint_type);
8817 break;
8818
8819 case GLSLstd450PackSnorm4x8:
8820 emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm4x8");
8821 break;
8822 case GLSLstd450PackUnorm4x8:
8823 emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm4x8");
8824 break;
8825 case GLSLstd450PackSnorm2x16:
8826 emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm2x16");
8827 break;
8828 case GLSLstd450PackUnorm2x16:
8829 emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm2x16");
8830 break;
8831
8832 case GLSLstd450PackHalf2x16:
8833 {
8834 auto expr = join("as_type<uint>(half2(", to_expression(args[0]), "))");
8835 emit_op(result_type, id, expr, should_forward(args[0]));
8836 inherit_expression_dependencies(id, args[0]);
8837 break;
8838 }
8839
8840 case GLSLstd450UnpackSnorm4x8:
8841 emit_unary_func_op(result_type, id, args[0], "unpack_snorm4x8_to_float");
8842 break;
8843 case GLSLstd450UnpackUnorm4x8:
8844 emit_unary_func_op(result_type, id, args[0], "unpack_unorm4x8_to_float");
8845 break;
8846 case GLSLstd450UnpackSnorm2x16:
8847 emit_unary_func_op(result_type, id, args[0], "unpack_snorm2x16_to_float");
8848 break;
8849 case GLSLstd450UnpackUnorm2x16:
8850 emit_unary_func_op(result_type, id, args[0], "unpack_unorm2x16_to_float");
8851 break;
8852
8853 case GLSLstd450UnpackHalf2x16:
8854 {
8855 auto expr = join("float2(as_type<half2>(", to_expression(args[0]), "))");
8856 emit_op(result_type, id, expr, should_forward(args[0]));
8857 inherit_expression_dependencies(id, args[0]);
8858 break;
8859 }
8860
8861 case GLSLstd450PackDouble2x32:
8862 emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450PackDouble2x32"); // Currently unsupported
8863 break;
8864 case GLSLstd450UnpackDouble2x32:
8865 emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450UnpackDouble2x32"); // Currently unsupported
8866 break;
8867
8868 case GLSLstd450MatrixInverse:
8869 {
8870 auto &mat_type = get<SPIRType>(result_type);
8871 switch (mat_type.columns)
8872 {
8873 case 2:
8874 emit_unary_func_op(result_type, id, args[0], "spvInverse2x2");
8875 break;
8876 case 3:
8877 emit_unary_func_op(result_type, id, args[0], "spvInverse3x3");
8878 break;
8879 case 4:
8880 emit_unary_func_op(result_type, id, args[0], "spvInverse4x4");
8881 break;
8882 default:
8883 break;
8884 }
8885 break;
8886 }
8887
8888 case GLSLstd450FMin:
8889 // If the result type isn't float, don't bother calling the specific
8890 // precise::/fast:: version. Metal doesn't have those for half and
8891 // double types.
8892 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8893 emit_binary_func_op(result_type, id, args[0], args[1], "min");
8894 else
8895 emit_binary_func_op(result_type, id, args[0], args[1], "fast::min");
8896 break;
8897
8898 case GLSLstd450FMax:
8899 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8900 emit_binary_func_op(result_type, id, args[0], args[1], "max");
8901 else
8902 emit_binary_func_op(result_type, id, args[0], args[1], "fast::max");
8903 break;
8904
8905 case GLSLstd450FClamp:
8906 // TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
8907 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8908 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
8909 else
8910 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "fast::clamp");
8911 break;
8912
8913 case GLSLstd450NMin:
8914 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8915 emit_binary_func_op(result_type, id, args[0], args[1], "min");
8916 else
8917 emit_binary_func_op(result_type, id, args[0], args[1], "precise::min");
8918 break;
8919
8920 case GLSLstd450NMax:
8921 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8922 emit_binary_func_op(result_type, id, args[0], args[1], "max");
8923 else
8924 emit_binary_func_op(result_type, id, args[0], args[1], "precise::max");
8925 break;
8926
8927 case GLSLstd450NClamp:
8928 // TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
8929 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
8930 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
8931 else
8932 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "precise::clamp");
8933 break;
8934
8935 case GLSLstd450InterpolateAtCentroid:
8936 {
8937 // We can't just emit the expression normally, because the qualified name contains a call to the default
8938 // interpolate method, or refers to a local variable. We saved the interface index we need; use it to construct
8939 // the base for the method call.
8940 uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
8941 string component;
8942 if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
8943 {
8944 uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
8945 auto *c = maybe_get<SPIRConstant>(index_expr);
8946 if (!c || c->specialization)
8947 component = join("[", to_expression(index_expr), "]");
8948 else
8949 component = join(".", index_to_swizzle(c->scalar()));
8950 }
8951 emit_op(result_type, id,
8952 join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
8953 ".interpolate_at_centroid()", component),
8954 should_forward(args[0]));
8955 break;
8956 }
8957
8958 case GLSLstd450InterpolateAtSample:
8959 {
8960 uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
8961 string component;
8962 if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
8963 {
8964 uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
8965 auto *c = maybe_get<SPIRConstant>(index_expr);
8966 if (!c || c->specialization)
8967 component = join("[", to_expression(index_expr), "]");
8968 else
8969 component = join(".", index_to_swizzle(c->scalar()));
8970 }
8971 emit_op(result_type, id,
8972 join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
8973 ".interpolate_at_sample(", to_expression(args[1]), ")", component),
8974 should_forward(args[0]) && should_forward(args[1]));
8975 break;
8976 }
8977
8978 case GLSLstd450InterpolateAtOffset:
8979 {
8980 uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
8981 string component;
8982 if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
8983 {
8984 uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
8985 auto *c = maybe_get<SPIRConstant>(index_expr);
8986 if (!c || c->specialization)
8987 component = join("[", to_expression(index_expr), "]");
8988 else
8989 component = join(".", index_to_swizzle(c->scalar()));
8990 }
8991 // Like Direct3D, Metal puts the (0, 0) at the upper-left corner, not the center as SPIR-V and GLSL do.
8992 // Offset the offset by (1/2 - 1/16), or 0.4375, to compensate for this.
8993 // It has to be (1/2 - 1/16) and not 1/2, or several CTS tests subtly break on Intel.
8994 emit_op(result_type, id,
8995 join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
8996 ".interpolate_at_offset(", to_expression(args[1]), " + 0.4375)", component),
8997 should_forward(args[0]) && should_forward(args[1]));
8998 break;
8999 }
9000
9001 case GLSLstd450Distance:
9002 // MSL does not support scalar versions here.
9003 if (expression_type(args[0]).vecsize == 1)
9004 {
9005 // Equivalent to length(a - b) -> abs(a - b).
9006 emit_op(result_type, id,
9007 join("abs(", to_enclosed_unpacked_expression(args[0]), " - ",
9008 to_enclosed_unpacked_expression(args[1]), ")"),
9009 should_forward(args[0]) && should_forward(args[1]));
9010 inherit_expression_dependencies(id, args[0]);
9011 inherit_expression_dependencies(id, args[1]);
9012 }
9013 else
9014 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9015 break;
9016
9017 case GLSLstd450Length:
9018 // MSL does not support scalar versions here.
9019 if (expression_type(args[0]).vecsize == 1)
9020 {
9021 // Equivalent to abs().
9022 emit_unary_func_op(result_type, id, args[0], "abs");
9023 }
9024 else
9025 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9026 break;
9027
9028 case GLSLstd450Normalize:
9029 // MSL does not support scalar versions here.
9030 if (expression_type(args[0]).vecsize == 1)
9031 {
9032 // Returns -1 or 1 for valid input, sign() does the job.
9033 emit_unary_func_op(result_type, id, args[0], "sign");
9034 }
9035 else
9036 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9037 break;
9038
9039 case GLSLstd450Reflect:
9040 if (get<SPIRType>(result_type).vecsize == 1)
9041 emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect");
9042 else
9043 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9044 break;
9045
9046 case GLSLstd450Refract:
9047 if (get<SPIRType>(result_type).vecsize == 1)
9048 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract");
9049 else
9050 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9051 break;
9052
9053 case GLSLstd450FaceForward:
9054 if (get<SPIRType>(result_type).vecsize == 1)
9055 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvFaceForward");
9056 else
9057 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9058 break;
9059
9060 case GLSLstd450Modf:
9061 case GLSLstd450Frexp:
9062 {
9063 // Special case. If the variable is a scalar access chain, we cannot use it directly. We have to emit a temporary.
9064 auto *ptr = maybe_get<SPIRExpression>(args[1]);
9065 if (ptr && ptr->access_chain && is_scalar(expression_type(args[1])))
9066 {
9067 register_call_out_argument(args[1]);
9068 forced_temporaries.insert(id);
9069
9070 // Need to create temporaries and copy over to access chain after.
9071 // We cannot directly take the reference of a vector swizzle in MSL, even if it's scalar ...
9072 uint32_t &tmp_id = extra_sub_expressions[id];
9073 if (!tmp_id)
9074 tmp_id = ir.increase_bound_by(1);
9075
9076 uint32_t tmp_type_id = get_pointee_type_id(ptr->expression_type);
9077 emit_uninitialized_temporary_expression(tmp_type_id, tmp_id);
9078 emit_binary_func_op(result_type, id, args[0], tmp_id, eop == GLSLstd450Modf ? "modf" : "frexp");
9079 statement(to_expression(args[1]), " = ", to_expression(tmp_id), ";");
9080 }
9081 else
9082 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9083 break;
9084 }
9085
9086 default:
9087 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9088 break;
9089 }
9090 }
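// A few representative lowerings from the switch above (output sketches, not an
// exhaustive list; "in" stands for the stage-in struct):
//     atan(y, x)                -> atan2(y, x)
//     inversesqrt(x)            -> rsqrt(x)
//     roundEven(x)              -> rint(x)
//     packUnorm4x8(v)           -> pack_float_to_unorm4x8(v)
//     interpolateAtOffset(v, o) -> in.v.interpolate_at_offset(o + 0.4375)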
9091
9092 void CompilerMSL::emit_spv_amd_shader_trinary_minmax_op(uint32_t result_type, uint32_t id, uint32_t eop,
9093 const uint32_t *args, uint32_t count)
9094 {
9095 enum AMDShaderTrinaryMinMax
9096 {
9097 FMin3AMD = 1,
9098 UMin3AMD = 2,
9099 SMin3AMD = 3,
9100 FMax3AMD = 4,
9101 UMax3AMD = 5,
9102 SMax3AMD = 6,
9103 FMid3AMD = 7,
9104 UMid3AMD = 8,
9105 SMid3AMD = 9
9106 };
9107
9108 if (!msl_options.supports_msl_version(2, 1))
9109 SPIRV_CROSS_THROW("Trinary min/max functions require MSL 2.1.");
9110
9111 auto op = static_cast<AMDShaderTrinaryMinMax>(eop);
9112
9113 switch (op)
9114 {
9115 case FMid3AMD:
9116 case UMid3AMD:
9117 case SMid3AMD:
9118 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "median3");
9119 break;
9120 default:
9121 CompilerGLSL::emit_spv_amd_shader_trinary_minmax_op(result_type, id, eop, args, count);
9122 break;
9123 }
9124 }
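// Example: SMid3AMD(a, b, c) lowers to Metal's median3(a, b, c). The min3/max3 variants
// fall through to the base implementation, since MSL's min3()/max3() share the names
// used by the GLSL path.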
9125
9126 // Emit a structure declaration for the specified interface variable.
9127 void CompilerMSL::emit_interface_block(uint32_t ib_var_id)
9128 {
9129 if (ib_var_id)
9130 {
9131 auto &ib_var = get<SPIRVariable>(ib_var_id);
9132 auto &ib_type = get_variable_data_type(ib_var);
9133 //assert(ib_type.basetype == SPIRType::Struct && !ib_type.member_types.empty());
9134 assert(ib_type.basetype == SPIRType::Struct);
9135 emit_struct(ib_type);
9136 }
9137 }
9138
9139 // Emits the declaration signature of the specified function.
9140 // If this is the entry point function, Metal-specific return value and function arguments are added.
9141 void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &)
9142 {
9143 if (func.self != ir.default_entry_point)
9144 add_function_overload(func);
9145
9146 local_variable_names = resource_names;
9147 string decl;
9148
9149 processing_entry_point = func.self == ir.default_entry_point;
9150
9151 // Metal helper functions must be static force-inline; otherwise they will cause problems when linked together in a single Metallib.
9152 if (!processing_entry_point)
9153 statement(force_inline);
9154
9155 auto &type = get<SPIRType>(func.return_type);
9156
9157 if (!type.array.empty() && msl_options.force_native_arrays)
9158 {
9159 // We cannot return native arrays in MSL, so "return" through an out variable.
9160 decl += "void";
9161 }
9162 else
9163 {
9164 decl += func_type_decl(type);
9165 }
9166
9167 decl += " ";
9168 decl += to_name(func.self);
9169 decl += "(";
9170
9171 if (!type.array.empty() && msl_options.force_native_arrays)
9172 {
9173 // Fake array returns by writing to an out array instead.
9174 decl += "thread ";
9175 decl += type_to_glsl(type);
9176 decl += " (&spvReturnValue)";
9177 decl += type_to_array_glsl(type);
9178 if (!func.arguments.empty())
9179 decl += ", ";
9180 }
9181
9182 if (processing_entry_point)
9183 {
9184 if (msl_options.argument_buffers)
9185 decl += entry_point_args_argument_buffer(!func.arguments.empty());
9186 else
9187 decl += entry_point_args_classic(!func.arguments.empty());
9188
9189 // If entry point function has variables that require early declaration,
9190 // ensure they each have an empty initializer, creating one if needed.
9191 // This is done at this late stage because the initialization expression
9192 // is cleared after each compilation pass.
9193 for (auto var_id : vars_needing_early_declaration)
9194 {
9195 auto &ed_var = get<SPIRVariable>(var_id);
9196 ID &initializer = ed_var.initializer;
9197 if (!initializer)
9198 initializer = ir.increase_bound_by(1);
9199
9200 // Do not override proper initializers.
9201 if (ir.ids[initializer].get_type() == TypeNone || ir.ids[initializer].get_type() == TypeExpression)
9202 set<SPIRExpression>(ed_var.initializer, "{}", ed_var.basetype, true);
9203 }
9204 }
9205
9206 for (auto &arg : func.arguments)
9207 {
9208 uint32_t name_id = arg.id;
9209
9210 auto *var = maybe_get<SPIRVariable>(arg.id);
9211 if (var)
9212 {
9213 // If we need to modify the name of the variable, make sure we modify the original variable.
9214 // Our alias is just a shadow variable.
9215 if (arg.alias_global_variable && var->basevariable)
9216 name_id = var->basevariable;
9217
9218 var->parameter = &arg; // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
9219 }
9220
9221 add_local_variable_name(name_id);
9222
9223 decl += argument_decl(arg);
9224
9225 bool is_dynamic_img_sampler = has_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
9226
9227 auto &arg_type = get<SPIRType>(arg.type);
9228 if (arg_type.basetype == SPIRType::SampledImage && !is_dynamic_img_sampler)
9229 {
9230 // Manufacture automatic plane args for multiplanar texture
9231 uint32_t planes = 1;
9232 if (auto *constexpr_sampler = find_constexpr_sampler(name_id))
9233 if (constexpr_sampler->ycbcr_conversion_enable)
9234 planes = constexpr_sampler->planes;
9235 for (uint32_t i = 1; i < planes; i++)
9236 decl += join(", ", argument_decl(arg), plane_name_suffix, i);
9237
9238 // Manufacture automatic sampler arg for SampledImage texture
9239 if (arg_type.image.dim != DimBuffer)
9240 decl += join(", thread const ", sampler_type(arg_type, arg.id), " ", to_sampler_expression(arg.id));
9241 }
9242
9243 // Manufacture automatic swizzle arg.
9244 if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(arg_type) &&
9245 !is_dynamic_img_sampler)
9246 {
9247 bool arg_is_array = !arg_type.array.empty();
9248 decl += join(", constant uint", arg_is_array ? "* " : "& ", to_swizzle_expression(arg.id));
9249 }
9250
9251 if (buffers_requiring_array_length.count(name_id))
9252 {
9253 bool arg_is_array = !arg_type.array.empty();
9254 decl += join(", constant uint", arg_is_array ? "* " : "& ", to_buffer_size_expression(name_id));
9255 }
9256
9257 if (&arg != &func.arguments.back())
9258 decl += ", ";
9259 }
9260
9261 decl += ")";
9262 statement(decl);
9263 }
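// A sketch of the shape of an emitted helper prototype (hypothetical names), showing the
// manufactured sampler and swizzle arguments described above:
//     static inline __attribute__((always_inline))
//     float4 sampleColor(thread const texture2d<float> colorTex,
//                        thread const sampler colorTexSmplr, constant uint& colorTexSwzl)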
9264
9265 static bool needs_chroma_reconstruction(const MSLConstexprSampler *constexpr_sampler)
9266 {
9267 // For now, only multiplanar images need explicit reconstruction. GBGR and BGRG images
9268 // use implicit reconstruction.
9269 return constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && constexpr_sampler->planes > 1;
9270 }
9271
9272 // Returns the texture sampling function string for the specified image and sampling characteristics.
9273 string CompilerMSL::to_function_name(const TextureFunctionNameArguments &args)
9274 {
9275 VariableID img = args.base.img;
9276 auto &imgtype = *args.base.imgtype;
9277
9278 const MSLConstexprSampler *constexpr_sampler = nullptr;
9279 bool is_dynamic_img_sampler = false;
9280 if (auto *var = maybe_get_backing_variable(img))
9281 {
9282 constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
9283 is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
9284 }
9285
9286 // Special-case gather. We have to alter the component being looked up
9287 // in the swizzle case.
9288 if (msl_options.swizzle_texture_samples && args.base.is_gather && !is_dynamic_img_sampler &&
9289 (!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable))
9290 {
9291 add_spv_func_and_recompile(imgtype.image.depth ? SPVFuncImplGatherCompareSwizzle : SPVFuncImplGatherSwizzle);
9292 return imgtype.image.depth ? "spvGatherCompareSwizzle" : "spvGatherSwizzle";
9293 }
9294
9295 auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
9296
9297 // Texture reference
9298 string fname;
9299 if (needs_chroma_reconstruction(constexpr_sampler) && !is_dynamic_img_sampler)
9300 {
9301 if (constexpr_sampler->planes != 2 && constexpr_sampler->planes != 3)
9302 SPIRV_CROSS_THROW("Unhandled number of color image planes!");
9303 // 444 images aren't downsampled, so we don't need to do linear filtering.
9304 if (constexpr_sampler->resolution == MSL_FORMAT_RESOLUTION_444 ||
9305 constexpr_sampler->chroma_filter == MSL_SAMPLER_FILTER_NEAREST)
9306 {
9307 if (constexpr_sampler->planes == 2)
9308 add_spv_func_and_recompile(SPVFuncImplChromaReconstructNearest2Plane);
9309 else
9310 add_spv_func_and_recompile(SPVFuncImplChromaReconstructNearest3Plane);
9311 fname = "spvChromaReconstructNearest";
9312 }
9313 else // Linear with a downsampled format
9314 {
9315 fname = "spvChromaReconstructLinear";
9316 switch (constexpr_sampler->resolution)
9317 {
9318 case MSL_FORMAT_RESOLUTION_444:
9319 assert(false);
9320 break; // not reached
9321 case MSL_FORMAT_RESOLUTION_422:
9322 switch (constexpr_sampler->x_chroma_offset)
9323 {
9324 case MSL_CHROMA_LOCATION_COSITED_EVEN:
9325 if (constexpr_sampler->planes == 2)
9326 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422CositedEven2Plane);
9327 else
9328 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422CositedEven3Plane);
9329 fname += "422CositedEven";
9330 break;
9331 case MSL_CHROMA_LOCATION_MIDPOINT:
9332 if (constexpr_sampler->planes == 2)
9333 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422Midpoint2Plane);
9334 else
9335 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422Midpoint3Plane);
9336 fname += "422Midpoint";
9337 break;
9338 default:
9339 SPIRV_CROSS_THROW("Invalid chroma location.");
9340 }
9341 break;
9342 case MSL_FORMAT_RESOLUTION_420:
9343 fname += "420";
9344 switch (constexpr_sampler->x_chroma_offset)
9345 {
9346 case MSL_CHROMA_LOCATION_COSITED_EVEN:
9347 switch (constexpr_sampler->y_chroma_offset)
9348 {
9349 case MSL_CHROMA_LOCATION_COSITED_EVEN:
9350 if (constexpr_sampler->planes == 2)
9351 add_spv_func_and_recompile(
9352 SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane);
9353 else
9354 add_spv_func_and_recompile(
9355 SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane);
9356 fname += "XCositedEvenYCositedEven";
9357 break;
9358 case MSL_CHROMA_LOCATION_MIDPOINT:
9359 if (constexpr_sampler->planes == 2)
9360 add_spv_func_and_recompile(
9361 SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane);
9362 else
9363 add_spv_func_and_recompile(
9364 SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane);
9365 fname += "XCositedEvenYMidpoint";
9366 break;
9367 default:
9368 SPIRV_CROSS_THROW("Invalid Y chroma location.");
9369 }
9370 break;
9371 case MSL_CHROMA_LOCATION_MIDPOINT:
9372 switch (constexpr_sampler->y_chroma_offset)
9373 {
9374 case MSL_CHROMA_LOCATION_COSITED_EVEN:
9375 if (constexpr_sampler->planes == 2)
9376 add_spv_func_and_recompile(
9377 SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane);
9378 else
9379 add_spv_func_and_recompile(
9380 SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane);
9381 fname += "XMidpointYCositedEven";
9382 break;
9383 case MSL_CHROMA_LOCATION_MIDPOINT:
9384 if (constexpr_sampler->planes == 2)
9385 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane);
9386 else
9387 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane);
9388 fname += "XMidpointYMidpoint";
9389 break;
9390 default:
9391 SPIRV_CROSS_THROW("Invalid Y chroma location.");
9392 }
9393 break;
9394 default:
9395 SPIRV_CROSS_THROW("Invalid X chroma location.");
9396 }
9397 break;
9398 default:
9399 SPIRV_CROSS_THROW("Invalid format resolution.");
9400 }
9401 }
9402 }
9403 else
9404 {
9405 fname = to_expression(combined ? combined->image : img) + ".";
9406
9407 // Texture function and sampler
9408 if (args.base.is_fetch)
9409 fname += "read";
9410 else if (args.base.is_gather)
9411 fname += "gather";
9412 else
9413 fname += "sample";
9414
9415 if (args.has_dref)
9416 fname += "_compare";
9417 }
9418
9419 return fname;
9420 }
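// Examples of names this returns: "spvGatherSwizzle" / "spvGatherCompareSwizzle" for
// swizzled gathers, "spvChromaReconstructLinear422CositedEven" for Y'CbCr chroma
// reconstruction, or a member-call base such as "tex.sample", "tex.read" or
// "tex.gather_compare" in the common case.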
9421
9422 string CompilerMSL::convert_to_f32(const string &expr, uint32_t components)
9423 {
9424 SPIRType t;
9425 t.basetype = SPIRType::Float;
9426 t.vecsize = components;
9427 t.columns = 1;
9428 return join(type_to_glsl_constructor(t), "(", expr, ")");
9429 }
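// For example, convert_to_f32("h", 2) yields "float2(h)"; this is used to promote half
// (and, defensively, double) sampling parameters to the float forms Metal expects.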
9430
9431 static inline bool sampling_type_needs_f32_conversion(const SPIRType &type)
9432 {
9433 // Double is not supported to begin with, but it doesn't hurt to check for completeness.
9434 return type.basetype == SPIRType::Half || type.basetype == SPIRType::Double;
9435 }
9436
9437 // Returns the function args for a texture sampling function for the specified image and sampling characteristics.
9438 string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool *p_forward)
9439 {
9440 VariableID img = args.base.img;
9441 auto &imgtype = *args.base.imgtype;
9442 uint32_t lod = args.lod;
9443 uint32_t grad_x = args.grad_x;
9444 uint32_t grad_y = args.grad_y;
9445 uint32_t bias = args.bias;
9446
9447 const MSLConstexprSampler *constexpr_sampler = nullptr;
9448 bool is_dynamic_img_sampler = false;
9449 if (auto *var = maybe_get_backing_variable(img))
9450 {
9451 constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
9452 is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
9453 }
9454
9455 string farg_str;
9456 bool forward = true;
9457
9458 if (!is_dynamic_img_sampler)
9459 {
9460 // Texture reference (for some cases)
9461 if (needs_chroma_reconstruction(constexpr_sampler))
9462 {
9463 // Multiplanar images need two or three textures.
9464 farg_str += to_expression(img);
9465 for (uint32_t i = 1; i < constexpr_sampler->planes; i++)
9466 farg_str += join(", ", to_expression(img), plane_name_suffix, i);
9467 }
9468 else if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
9469 msl_options.swizzle_texture_samples && args.base.is_gather)
9470 {
9471 auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
9472 farg_str += to_expression(combined ? combined->image : img);
9473 }
9474
9475 // Sampler reference
9476 if (!args.base.is_fetch)
9477 {
9478 if (!farg_str.empty())
9479 farg_str += ", ";
9480 farg_str += to_sampler_expression(img);
9481 }
9482
9483 if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
9484 msl_options.swizzle_texture_samples && args.base.is_gather)
9485 {
9486 // Add the swizzle constant from the swizzle buffer.
9487 farg_str += ", " + to_swizzle_expression(img);
9488 used_swizzle_buffer = true;
9489 }
9490
9491 // Swizzled gather puts the component before the other args, to allow template
9492 // deduction to work.
9493 if (args.component && msl_options.swizzle_texture_samples)
9494 {
9495 forward = should_forward(args.component);
9496 farg_str += ", " + to_component_argument(args.component);
9497 }
9498 }
9499
9500 // Texture coordinates
9501 forward = forward && should_forward(args.coord);
9502 auto coord_expr = to_enclosed_expression(args.coord);
9503 auto &coord_type = expression_type(args.coord);
9504 bool coord_is_fp = type_is_floating_point(coord_type);
9505 bool is_cube_fetch = false;
9506
9507 string tex_coords = coord_expr;
9508 uint32_t alt_coord_component = 0;
9509
9510 switch (imgtype.image.dim)
9511 {
9513 case Dim1D:
9514 if (coord_type.vecsize > 1)
9515 tex_coords = enclose_expression(tex_coords) + ".x";
9516
9517 if (args.base.is_fetch)
9518 tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9519 else if (sampling_type_needs_f32_conversion(coord_type))
9520 tex_coords = convert_to_f32(tex_coords, 1);
9521
9522 if (msl_options.texture_1D_as_2D)
9523 {
9524 if (args.base.is_fetch)
9525 tex_coords = "uint2(" + tex_coords + ", 0)";
9526 else
9527 tex_coords = "float2(" + tex_coords + ", 0.5)";
9528 }
9529
9530 alt_coord_component = 1;
9531 break;
9532
9533 case DimBuffer:
9534 if (coord_type.vecsize > 1)
9535 tex_coords = enclose_expression(tex_coords) + ".x";
9536
9537 if (msl_options.texture_buffer_native)
9538 {
9539 tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9540 }
9541 else
9542 {
9543 // Metal texel buffer textures are 2D, so convert the 1D coord to 2D.
9544 // (The native texture_buffer type introduced in Metal 2.1 is handled above.)
9545 if (args.base.is_fetch)
9546 {
9547 if (msl_options.texel_buffer_texture_width > 0)
9548 {
9549 tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9550 }
9551 else
9552 {
9553 tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ", " +
9554 to_expression(img) + ")";
9555 }
9556 }
9557 }
9558
9559 alt_coord_component = 1;
9560 break;
9561
9562 case DimSubpassData:
9563 // If we're using Metal's native frame-buffer fetch API for subpass inputs,
9564 // this path will not be hit.
9565 tex_coords = "uint2(gl_FragCoord.xy)";
9566 alt_coord_component = 2;
9567 break;
9568
9569 case Dim2D:
9570 if (coord_type.vecsize > 2)
9571 tex_coords = enclose_expression(tex_coords) + ".xy";
9572
9573 if (args.base.is_fetch)
9574 tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9575 else if (sampling_type_needs_f32_conversion(coord_type))
9576 tex_coords = convert_to_f32(tex_coords, 2);
9577
9578 alt_coord_component = 2;
9579 break;
9580
9581 case Dim3D:
9582 if (coord_type.vecsize > 3)
9583 tex_coords = enclose_expression(tex_coords) + ".xyz";
9584
9585 if (args.base.is_fetch)
9586 tex_coords = "uint3(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9587 else if (sampling_type_needs_f32_conversion(coord_type))
9588 tex_coords = convert_to_f32(tex_coords, 3);
9589
9590 alt_coord_component = 3;
9591 break;
9592
9593 case DimCube:
9594 if (args.base.is_fetch)
9595 {
9596 is_cube_fetch = true;
9597 tex_coords += ".xy";
9598 tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9599 }
9600 else
9601 {
9602 if (coord_type.vecsize > 3)
9603 tex_coords = enclose_expression(tex_coords) + ".xyz";
9604 }
9605
9606 if (sampling_type_needs_f32_conversion(coord_type))
9607 tex_coords = convert_to_f32(tex_coords, 3);
9608
9609 alt_coord_component = 3;
9610 break;
9611
9612 default:
9613 break;
9614 }
9615
9616 if (args.base.is_fetch && (args.offset || args.coffset))
9617 {
9618 uint32_t offset_expr = args.offset ? args.offset : args.coffset;
9619 // Fetch offsets must be applied directly to the coordinate.
9620 forward = forward && should_forward(offset_expr);
9621 auto &type = expression_type(offset_expr);
9622 if (imgtype.image.dim == Dim1D && msl_options.texture_1D_as_2D)
9623 {
9624 if (type.basetype != SPIRType::UInt)
9625 tex_coords += join(" + uint2(", bitcast_expression(SPIRType::UInt, offset_expr), ", 0)");
9626 else
9627 tex_coords += join(" + uint2(", to_enclosed_expression(offset_expr), ", 0)");
9628 }
9629 else
9630 {
9631 if (type.basetype != SPIRType::UInt)
9632 tex_coords += " + " + bitcast_expression(SPIRType::UInt, offset_expr);
9633 else
9634 tex_coords += " + " + to_enclosed_expression(offset_expr);
9635 }
9636 }
9637
9638 // If projection, use alt coord as divisor
9639 if (args.base.is_proj)
9640 {
9641 if (sampling_type_needs_f32_conversion(coord_type))
9642 tex_coords += " / " + convert_to_f32(to_extract_component_expression(args.coord, alt_coord_component), 1);
9643 else
9644 tex_coords += " / " + to_extract_component_expression(args.coord, alt_coord_component);
9645 }
9646
9647 if (!farg_str.empty())
9648 farg_str += ", ";
9649
9650 if (imgtype.image.dim == DimCube && imgtype.image.arrayed && msl_options.emulate_cube_array)
9651 {
9652 farg_str += "spvCubemapTo2DArrayFace(" + tex_coords + ").xy";
9653
9654 if (is_cube_fetch)
9655 farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ")";
9656 else
9657 farg_str +=
9658 ", uint(spvCubemapTo2DArrayFace(" + tex_coords + ").z) + (uint(" +
9659 round_fp_tex_coords(to_extract_component_expression(args.coord, alt_coord_component), coord_is_fp) +
9660 ") * 6u)";
9661
9662 add_spv_func_and_recompile(SPVFuncImplCubemapTo2DArrayFace);
9663 }
9664 else
9665 {
9666 farg_str += tex_coords;
9667
9668 // If fetch from cube, add face explicitly
9669 if (is_cube_fetch)
9670 {
9671 // Special case for cube arrays, face and layer are packed in one dimension.
9672 if (imgtype.image.arrayed)
9673 farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ") % 6u";
9674 else
9675 farg_str +=
9676 ", uint(" + round_fp_tex_coords(to_extract_component_expression(args.coord, 2), coord_is_fp) + ")";
9677 }
9678
9679 // If array, use alt coord
9680 if (imgtype.image.arrayed)
9681 {
9682 // Special case for cube arrays, face and layer are packed in one dimension.
9683 if (imgtype.image.dim == DimCube && args.base.is_fetch)
9684 {
9685 farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ") / 6u";
9686 }
9687 else
9688 {
9689 farg_str +=
9690 ", uint(" +
9691 round_fp_tex_coords(to_extract_component_expression(args.coord, alt_coord_component), coord_is_fp) +
9692 ")";
9693 if (imgtype.image.dim == DimSubpassData)
9694 {
9695 if (msl_options.multiview)
9696 farg_str += " + gl_ViewIndex";
9697 else if (msl_options.arrayed_subpass_input)
9698 farg_str += " + gl_Layer";
9699 }
9700 }
9701 }
9702 else if (imgtype.image.dim == DimSubpassData)
9703 {
9704 if (msl_options.multiview)
9705 farg_str += ", gl_ViewIndex";
9706 else if (msl_options.arrayed_subpass_input)
9707 farg_str += ", gl_Layer";
9708 }
9709 }
9710
9711 // Depth compare reference value
9712 if (args.dref)
9713 {
9714 forward = forward && should_forward(args.dref);
9715 farg_str += ", ";
9716
9717 auto &dref_type = expression_type(args.dref);
9718
9719 string dref_expr;
9720 if (args.base.is_proj)
9721 dref_expr = join(to_enclosed_expression(args.dref), " / ",
9722 to_extract_component_expression(args.coord, alt_coord_component));
9723 else
9724 dref_expr = to_expression(args.dref);
9725
9726 if (sampling_type_needs_f32_conversion(dref_type))
9727 dref_expr = convert_to_f32(dref_expr, 1);
9728
9729 farg_str += dref_expr;
9730
9731 if (msl_options.is_macos() && (grad_x || grad_y))
9732 {
9733 // For sample compare, MSL does not support gradient2d on all targets (per the docs, only iOS does).
9734 // However, the most common case here is to have a constant gradient of 0, as that is the only way to express
9735 // LOD == 0 in GLSL with sampler2DArrayShadow (cascaded shadow mapping).
9736 // We will detect a compile-time constant 0 value for gradient and promote that to level(0) on MSL.
9737 bool constant_zero_x = !grad_x || expression_is_constant_null(grad_x);
9738 bool constant_zero_y = !grad_y || expression_is_constant_null(grad_y);
9739 if (constant_zero_x && constant_zero_y)
9740 {
9741 lod = 0;
9742 grad_x = 0;
9743 grad_y = 0;
9744 farg_str += ", level(0)";
9745 }
9746 else if (!msl_options.supports_msl_version(2, 3))
9747 {
9748 SPIRV_CROSS_THROW("Using non-constant 0.0 gradient() qualifier for sample_compare. This is not "
9749 "supported on macOS prior to MSL 2.3.");
9750 }
9751 }
9752
9753 if (msl_options.is_macos() && bias)
9754 {
9755 // Bias is not supported either on macOS with sample_compare.
9756 // Verify it is compile-time zero, and drop the argument.
9757 if (expression_is_constant_null(bias))
9758 {
9759 bias = 0;
9760 }
9761 else if (!msl_options.supports_msl_version(2, 3))
9762 {
9763 SPIRV_CROSS_THROW("Using non-constant 0.0 bias() qualifier for sample_compare. This is not supported "
9764 "on macOS prior to MSL 2.3.");
9765 }
9766 }
9767 }
9768
9769 // LOD Options
9770 // Metal does not support LOD for 1D textures.
9771 if (bias && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
9772 {
9773 forward = forward && should_forward(bias);
9774 farg_str += ", bias(" + to_expression(bias) + ")";
9775 }
9776
9777 // Metal does not support LOD for 1D textures.
9778 if (lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
9779 {
9780 forward = forward && should_forward(lod);
9781 if (args.base.is_fetch)
9782 {
9783 farg_str += ", " + to_expression(lod);
9784 }
9785 else
9786 {
9787 farg_str += ", level(" + to_expression(lod) + ")";
9788 }
9789 }
9790 else if (args.base.is_fetch && !lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D) &&
9791 imgtype.image.dim != DimBuffer && !imgtype.image.ms && imgtype.image.sampled != 2)
9792 {
9793 // Lod argument is optional in OpImageFetch, but we require a LOD value, so pick 0 as the default.
9794 // Check for sampled type as well, because is_fetch is also used for OpImageRead in MSL.
9795 farg_str += ", 0";
9796 }
9797
9798 // Metal does not support LOD for 1D textures.
9799 if ((grad_x || grad_y) && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
9800 {
9801 forward = forward && should_forward(grad_x);
9802 forward = forward && should_forward(grad_y);
9803 string grad_opt;
9804 switch (imgtype.image.dim)
9805 {
9806 case Dim1D:
9807 case Dim2D:
9808 grad_opt = "2d";
9809 break;
9810 case Dim3D:
9811 grad_opt = "3d";
9812 break;
9813 case DimCube:
9814 if (imgtype.image.arrayed && msl_options.emulate_cube_array)
9815 grad_opt = "2d";
9816 else
9817 grad_opt = "cube";
9818 break;
9819 default:
9820 grad_opt = "unsupported_gradient_dimension";
9821 break;
9822 }
9823 farg_str += ", gradient" + grad_opt + "(" + to_expression(grad_x) + ", " + to_expression(grad_y) + ")";
9824 }
9825
9826 if (args.min_lod)
9827 {
9828 if (!msl_options.supports_msl_version(2, 2))
9829 SPIRV_CROSS_THROW("min_lod_clamp() is only supported in MSL 2.2+ and up.");
9830
9831 forward = forward && should_forward(args.min_lod);
9832 farg_str += ", min_lod_clamp(" + to_expression(args.min_lod) + ")";
9833 }
9834
9835 // Add offsets
9836 string offset_expr;
9837 const SPIRType *offset_type = nullptr;
9838 if (args.coffset && !args.base.is_fetch)
9839 {
9840 forward = forward && should_forward(args.coffset);
9841 offset_expr = to_expression(args.coffset);
9842 offset_type = &expression_type(args.coffset);
9843 }
9844 else if (args.offset && !args.base.is_fetch)
9845 {
9846 forward = forward && should_forward(args.offset);
9847 offset_expr = to_expression(args.offset);
9848 offset_type = &expression_type(args.offset);
9849 }
9850
9851 if (!offset_expr.empty())
9852 {
9853 switch (imgtype.image.dim)
9854 {
9855 case Dim1D:
9856 if (!msl_options.texture_1D_as_2D)
9857 break;
9858 if (offset_type->vecsize > 1)
9859 offset_expr = enclose_expression(offset_expr) + ".x";
9860
9861 farg_str += join(", int2(", offset_expr, ", 0)");
9862 break;
9863
9864 case Dim2D:
9865 if (offset_type->vecsize > 2)
9866 offset_expr = enclose_expression(offset_expr) + ".xy";
9867
9868 farg_str += ", " + offset_expr;
9869 break;
9870
9871 case Dim3D:
9872 if (offset_type->vecsize > 3)
9873 offset_expr = enclose_expression(offset_expr) + ".xyz";
9874
9875 farg_str += ", " + offset_expr;
9876 break;
9877
9878 default:
9879 break;
9880 }
9881 }
9882
9883 if (args.component)
9884 {
9885 // If a 2D gather has a component argument, ensure it also has an offset argument.
9886 if (imgtype.image.dim == Dim2D && offset_expr.empty())
9887 farg_str += ", int2(0)";
9888
9889 if (!msl_options.swizzle_texture_samples || is_dynamic_img_sampler)
9890 {
9891 forward = forward && should_forward(args.component);
9892
9893 uint32_t image_var = 0;
9894 if (const auto *combined = maybe_get<SPIRCombinedImageSampler>(img))
9895 {
9896 if (const auto *img_var = maybe_get_backing_variable(combined->image))
9897 image_var = img_var->self;
9898 }
9899 else if (const auto *var = maybe_get_backing_variable(img))
9900 {
9901 image_var = var->self;
9902 }
9903
9904 if (image_var == 0 || !image_is_comparison(expression_type(image_var), image_var))
9905 farg_str += ", " + to_component_argument(args.component);
9906 }
9907 }
9908
9909 if (args.sample)
9910 {
9911 forward = forward && should_forward(args.sample);
9912 farg_str += ", ";
9913 farg_str += to_expression(args.sample);
9914 }
9915
9916 *p_forward = forward;
9917
9918 return farg_str;
9919 }
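// Putting the pieces together, a biased sample of a 2D array texture such as
// texture(s, uv, b) in GLSL can come out roughly as (hypothetical names):
//     s.sample(sSmplr, in.uv.xy, uint(round(in.uv.z)), bias(b))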
9920
9921 // If the texture coordinates are floating point, invokes MSL round() function to round them.
9922 string CompilerMSL::round_fp_tex_coords(string tex_coords, bool coord_is_fp)
9923 {
9924 return coord_is_fp ? ("round(" + tex_coords + ")") : tex_coords;
9925 }
9926
9927 // Returns a string to use in an image sampling function argument.
9928 // The ID must be a scalar constant.
9929 string CompilerMSL::to_component_argument(uint32_t id)
9930 {
9931 uint32_t component_index = evaluate_constant_u32(id);
9932 switch (component_index)
9933 {
9934 case 0:
9935 return "component::x";
9936 case 1:
9937 return "component::y";
9938 case 2:
9939 return "component::z";
9940 case 3:
9941 return "component::w";
9942
9943 default:
9944 SPIRV_CROSS_THROW("The value (" + to_string(component_index) + ") of OpConstant ID " + to_string(id) +
9945 " is not a valid Component index, which must be one of 0, 1, 2, or 3.");
9946 }
9947 }
9948
9949 // Establish sampled image as expression object and assign the sampler to it.
9950 void CompilerMSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
9951 {
9952 set<SPIRCombinedImageSampler>(result_id, result_type, image_id, samp_id);
9953 }
9954
9955 string CompilerMSL::to_texture_op(const Instruction &i, bool sparse, bool *forward,
9956 SmallVector<uint32_t> &inherited_expressions)
9957 {
9958 auto *ops = stream(i);
9959 uint32_t result_type_id = ops[0];
9960 uint32_t img = ops[2];
9961 auto &result_type = get<SPIRType>(result_type_id);
9962 auto op = static_cast<Op>(i.op);
9963 bool is_gather = (op == OpImageGather || op == OpImageDrefGather);
9964
9965 // Bypass pointers because we need the real image struct
9966 auto &type = expression_type(img);
9967 auto &imgtype = get<SPIRType>(type.self);
9968
9969 const MSLConstexprSampler *constexpr_sampler = nullptr;
9970 bool is_dynamic_img_sampler = false;
9971 if (auto *var = maybe_get_backing_variable(img))
9972 {
9973 constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
9974 is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
9975 }
9976
9977 string expr;
9978 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
9979 {
9980 // If this needs sampler Y'CbCr conversion, we need to do some additional
9981 // processing.
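// e.g. BT.709 with ITU narrow range wraps the sample expression as (a sketch):
//   spvConvertYCbCrBT709(spvExpandITUNarrowRange(<sample>, bpc))
// with the closing parens and the bpc argument appended further below.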
9982 switch (constexpr_sampler->ycbcr_model)
9983 {
9984 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
9985 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
9986 // Default
9987 break;
9988 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
9989 add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT709);
9990 expr += "spvConvertYCbCrBT709(";
9991 break;
9992 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
9993 add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT601);
9994 expr += "spvConvertYCbCrBT601(";
9995 break;
9996 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
9997 add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT2020);
9998 expr += "spvConvertYCbCrBT2020(";
9999 break;
10000 default:
10001 SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
10002 }
10003
10004 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
10005 {
10006 switch (constexpr_sampler->ycbcr_range)
10007 {
10008 case MSL_SAMPLER_YCBCR_RANGE_ITU_FULL:
10009 add_spv_func_and_recompile(SPVFuncImplExpandITUFullRange);
10010 expr += "spvExpandITUFullRange(";
10011 break;
10012 case MSL_SAMPLER_YCBCR_RANGE_ITU_NARROW:
10013 add_spv_func_and_recompile(SPVFuncImplExpandITUNarrowRange);
10014 expr += "spvExpandITUNarrowRange(";
10015 break;
10016 default:
10017 SPIRV_CROSS_THROW("Invalid Y'CbCr range.");
10018 }
10019 }
10020 }
10021 else if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype) &&
10022 !is_dynamic_img_sampler)
10023 {
10024 add_spv_func_and_recompile(SPVFuncImplTextureSwizzle);
10025 expr += "spvTextureSwizzle(";
10026 }
10027
10028 string inner_expr = CompilerGLSL::to_texture_op(i, sparse, forward, inherited_expressions);
10029
10030 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
10031 {
10032 if (!constexpr_sampler->swizzle_is_identity())
10033 {
10034 static const char swizzle_names[] = "rgba";
10035 if (!constexpr_sampler->swizzle_has_one_or_zero())
10036 {
10037 // If we can, do it inline.
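// e.g. a constexpr swizzle of (A, B, G, R) simply appends ".abgr" to the
// sampling expression.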
10038 expr += inner_expr + ".";
10039 for (uint32_t c = 0; c < 4; c++)
10040 {
10041 switch (constexpr_sampler->swizzle[c])
10042 {
10043 case MSL_COMPONENT_SWIZZLE_IDENTITY:
10044 expr += swizzle_names[c];
10045 break;
10046 case MSL_COMPONENT_SWIZZLE_R:
10047 case MSL_COMPONENT_SWIZZLE_G:
10048 case MSL_COMPONENT_SWIZZLE_B:
10049 case MSL_COMPONENT_SWIZZLE_A:
10050 expr += swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
10051 break;
10052 default:
10053 SPIRV_CROSS_THROW("Invalid component swizzle.");
10054 }
10055 }
10056 }
10057 else
10058 {
10059 // Otherwise, we need to emit a temporary and swizzle that.
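// e.g. with an (R, ZERO, ONE, A) swizzle this becomes something like
//   float4(_123.r, 0, 1, _123.a)
// where _123 is the hoisted temporary (illustrative ID).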
10060 uint32_t temp_id = ir.increase_bound_by(1);
10061 emit_op(result_type_id, temp_id, inner_expr, false);
10062 for (auto &inherit : inherited_expressions)
10063 inherit_expression_dependencies(temp_id, inherit);
10064 inherited_expressions.clear();
10065 inherited_expressions.push_back(temp_id);
10066
10067 switch (op)
10068 {
10069 case OpImageSampleDrefImplicitLod:
10070 case OpImageSampleImplicitLod:
10071 case OpImageSampleProjImplicitLod:
10072 case OpImageSampleProjDrefImplicitLod:
10073 register_control_dependent_expression(temp_id);
10074 break;
10075
10076 default:
10077 break;
10078 }
10079 expr += type_to_glsl(result_type) + "(";
10080 for (uint32_t c = 0; c < 4; c++)
10081 {
10082 switch (constexpr_sampler->swizzle[c])
10083 {
10084 case MSL_COMPONENT_SWIZZLE_IDENTITY:
10085 expr += to_expression(temp_id) + "." + swizzle_names[c];
10086 break;
10087 case MSL_COMPONENT_SWIZZLE_ZERO:
10088 expr += "0";
10089 break;
10090 case MSL_COMPONENT_SWIZZLE_ONE:
10091 expr += "1";
10092 break;
10093 case MSL_COMPONENT_SWIZZLE_R:
10094 case MSL_COMPONENT_SWIZZLE_G:
10095 case MSL_COMPONENT_SWIZZLE_B:
10096 case MSL_COMPONENT_SWIZZLE_A:
10097 expr += to_expression(temp_id) + "." +
10098 swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
10099 break;
10100 default:
10101 SPIRV_CROSS_THROW("Invalid component swizzle.");
10102 }
10103 if (c < 3)
10104 expr += ", ";
10105 }
10106 expr += ")";
10107 }
10108 }
10109 else
10110 expr += inner_expr;
10111 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
10112 {
10113 expr += join(", ", constexpr_sampler->bpc, ")");
10114 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY)
10115 expr += ")";
10116 }
10117 }
10118 else
10119 {
10120 expr += inner_expr;
10121 if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype) &&
10122 !is_dynamic_img_sampler)
10123 {
10124 // Add the swizzle constant from the swizzle buffer.
10125 expr += ", " + to_swizzle_expression(img) + ")";
10126 used_swizzle_buffer = true;
10127 }
10128 }
10129
10130 return expr;
10131 }
10132
10133 static string create_swizzle(MSLComponentSwizzle swizzle)
10134 {
10135 switch (swizzle)
10136 {
10137 case MSL_COMPONENT_SWIZZLE_IDENTITY:
10138 return "spvSwizzle::none";
10139 case MSL_COMPONENT_SWIZZLE_ZERO:
10140 return "spvSwizzle::zero";
10141 case MSL_COMPONENT_SWIZZLE_ONE:
10142 return "spvSwizzle::one";
10143 case MSL_COMPONENT_SWIZZLE_R:
10144 return "spvSwizzle::red";
10145 case MSL_COMPONENT_SWIZZLE_G:
10146 return "spvSwizzle::green";
10147 case MSL_COMPONENT_SWIZZLE_B:
10148 return "spvSwizzle::blue";
10149 case MSL_COMPONENT_SWIZZLE_A:
10150 return "spvSwizzle::alpha";
10151 default:
10152 SPIRV_CROSS_THROW("Invalid component swizzle.");
10153 }
10154 }
10155
10156 // Returns a string representation of the ID, usable as a function arg.
10157 // Manufacture automatic sampler arg for SampledImage texture.
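// e.g. a combined image-sampler parameter "tex" typically expands to the pair
// "tex, texSmplr" (a sketch, assuming the default "Smplr" sampler_name_suffix).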
10158 string CompilerMSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
10159 {
10160 string arg_str;
10161
10162 auto &type = expression_type(id);
10163 bool is_dynamic_img_sampler = has_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
10164 // If the argument *itself* is a "dynamic" combined-image sampler, then we can just pass that around.
10165 bool arg_is_dynamic_img_sampler = has_extended_decoration(id, SPIRVCrossDecorationDynamicImageSampler);
10166 if (is_dynamic_img_sampler && !arg_is_dynamic_img_sampler)
10167 arg_str = join("spvDynamicImageSampler<", type_to_glsl(get<SPIRType>(type.image.type)), ">(");
10168
10169 auto *c = maybe_get<SPIRConstant>(id);
10170 if (msl_options.force_native_arrays && c && !get<SPIRType>(c->constant_type).array.empty())
10171 {
10172 // If we are passing a constant array directly to a function for some reason,
10173 // the callee will expect an argument in thread const address space
10174 // (since we can only bind to arrays with references in MSL).
10175 // To resolve this, we must emit a copy in this address space.
10176 // This kind of code gen should be rare enough that performance is not a real concern.
10177 // Inline the SPIR-V to avoid this kind of suboptimal codegen.
10178 //
10179 // We risk calling this inside a continue block (invalid code),
10180 // so just create a thread local copy in the current function.
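// e.g. for ID 42 the argument is emitted as "_42_array_copy", a thread-local
// copy declared in the current function (illustrative ID).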
10181 arg_str = join("_", id, "_array_copy");
10182 auto &constants = current_function->constant_arrays_needed_on_stack;
10183 auto itr = find(begin(constants), end(constants), ID(id));
10184 if (itr == end(constants))
10185 {
10186 force_recompile();
10187 constants.push_back(id);
10188 }
10189 }
10190 else
10191 arg_str += CompilerGLSL::to_func_call_arg(arg, id);
10192
10193 // Need to check the base variable in case we need to apply a qualified alias.
10194 uint32_t var_id = 0;
10195 auto *var = maybe_get<SPIRVariable>(id);
10196 if (var)
10197 var_id = var->basevariable;
10198
10199 if (!arg_is_dynamic_img_sampler)
10200 {
10201 auto *constexpr_sampler = find_constexpr_sampler(var_id ? var_id : id);
10202 if (type.basetype == SPIRType::SampledImage)
10203 {
10204 // Manufacture automatic plane args for multiplanar texture
10205 uint32_t planes = 1;
10206 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
10207 {
10208 planes = constexpr_sampler->planes;
10209 // If this parameter isn't aliasing a global, then we need to use
10210 // the special "dynamic image-sampler" class to pass it--and we need
10211 // to use it for *every* non-alias parameter, in case a combined
10212 // image-sampler with a Y'CbCr conversion is passed. Hopefully, this
10213 // pathological case is so rare that it should never be hit in practice.
10214 if (!arg.alias_global_variable)
10215 add_spv_func_and_recompile(SPVFuncImplDynamicImageSampler);
10216 }
10217 for (uint32_t i = 1; i < planes; i++)
10218 arg_str += join(", ", CompilerGLSL::to_func_call_arg(arg, id), plane_name_suffix, i);
10219 // Manufacture automatic sampler arg if the arg is a SampledImage texture.
10220 if (type.image.dim != DimBuffer)
10221 arg_str += ", " + to_sampler_expression(var_id ? var_id : id);
10222
10223 // Add sampler Y'CbCr conversion info if we have it
10224 if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
10225 {
10226 SmallVector<string> samp_args;
10227
10228 switch (constexpr_sampler->resolution)
10229 {
10230 case MSL_FORMAT_RESOLUTION_444:
10231 // Default
10232 break;
10233 case MSL_FORMAT_RESOLUTION_422:
10234 samp_args.push_back("spvFormatResolution::_422");
10235 break;
10236 case MSL_FORMAT_RESOLUTION_420:
10237 samp_args.push_back("spvFormatResolution::_420");
10238 break;
10239 default:
10240 SPIRV_CROSS_THROW("Invalid format resolution.");
10241 }
10242
10243 if (constexpr_sampler->chroma_filter != MSL_SAMPLER_FILTER_NEAREST)
10244 samp_args.push_back("spvChromaFilter::linear");
10245
10246 if (constexpr_sampler->x_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
10247 samp_args.push_back("spvXChromaLocation::midpoint");
10248 if (constexpr_sampler->y_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
10249 samp_args.push_back("spvYChromaLocation::midpoint");
10250 switch (constexpr_sampler->ycbcr_model)
10251 {
10252 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
10253 // Default
10254 break;
10255 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
10256 samp_args.push_back("spvYCbCrModelConversion::ycbcr_identity");
10257 break;
10258 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
10259 samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_709");
10260 break;
10261 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
10262 samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_601");
10263 break;
10264 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
10265 samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_2020");
10266 break;
10267 default:
10268 SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
10269 }
10270 if (constexpr_sampler->ycbcr_range != MSL_SAMPLER_YCBCR_RANGE_ITU_FULL)
10271 samp_args.push_back("spvYCbCrRange::itu_narrow");
10272 samp_args.push_back(join("spvComponentBits(", constexpr_sampler->bpc, ")"));
10273 arg_str += join(", spvYCbCrSampler(", merge(samp_args), ")");
10274 }
10275 }
10276
10277 if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
10278 arg_str += join(", (uint(", create_swizzle(constexpr_sampler->swizzle[3]), ") << 24) | (uint(",
10279 create_swizzle(constexpr_sampler->swizzle[2]), ") << 16) | (uint(",
10280 create_swizzle(constexpr_sampler->swizzle[1]), ") << 8) | uint(",
10281 create_swizzle(constexpr_sampler->swizzle[0]), ")");
10282 else if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
10283 arg_str += ", " + to_swizzle_expression(var_id ? var_id : id);
10284
10285 if (buffers_requiring_array_length.count(var_id))
10286 arg_str += ", " + to_buffer_size_expression(var_id ? var_id : id);
10287
10288 if (is_dynamic_img_sampler)
10289 arg_str += ")";
10290 }
10291
10292 // Emulate texture2D atomic operations
10293 auto *backing_var = maybe_get_backing_variable(var_id);
10294 if (backing_var && atomic_image_vars.count(backing_var->self))
10295 {
10296 arg_str += ", " + to_expression(var_id) + "_atomic";
10297 }
10298
10299 return arg_str;
10300 }
10301
10302 // If the ID represents a sampled image that has been assigned a sampler already,
10303 // generate an expression for the sampler, otherwise generate a fake sampler name
10304 // by appending a suffix to the expression constructed from the ID.
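// e.g. "myTex" becomes "myTexSmplr", and an arrayed image expression such as
// "myTex[2]" becomes "myTexSmplr[2]" (assuming the default "Smplr" suffix).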
10305 string CompilerMSL::to_sampler_expression(uint32_t id)
10306 {
10307 auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
10308 auto expr = to_expression(combined ? combined->image : VariableID(id));
10309 auto index = expr.find_first_of('[');
10310
10311 uint32_t samp_id = 0;
10312 if (combined)
10313 samp_id = combined->sampler;
10314
10315 if (index == string::npos)
10316 return samp_id ? to_expression(samp_id) : expr + sampler_name_suffix;
10317 else
10318 {
10319 auto image_expr = expr.substr(0, index);
10320 auto array_expr = expr.substr(index);
10321 return samp_id ? to_expression(samp_id) : (image_expr + sampler_name_suffix + array_expr);
10322 }
10323 }
10324
10325 string CompilerMSL::to_swizzle_expression(uint32_t id)
10326 {
10327 auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
10328
10329 auto expr = to_expression(combined ? combined->image : VariableID(id));
10330 auto index = expr.find_first_of('[');
10331
10332 // If an image is part of an argument buffer translate this to a legal identifier.
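// e.g. "spvDescriptorSet0.myTex" becomes "spvDescriptorSet0_myTexSwzl"
// (a sketch, assuming the default "Swzl" swizzle_name_suffix).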
10333 string::size_type period = 0;
10334 while ((period = expr.find_first_of('.', period)) != string::npos && period < index)
10335 expr[period] = '_';
10336
10337 if (index == string::npos)
10338 return expr + swizzle_name_suffix;
10339 else
10340 {
10341 auto image_expr = expr.substr(0, index);
10342 auto array_expr = expr.substr(index);
10343 return image_expr + swizzle_name_suffix + array_expr;
10344 }
10345 }
10346
10347 string CompilerMSL::to_buffer_size_expression(uint32_t id)
10348 {
10349 auto expr = to_expression(id);
10350 auto index = expr.find_first_of('[');
10351
10352 // This is quite crude, but we need to translate the reference name (*spvDescriptorSetN.name) to
10353 // the pointer expression spvDescriptorSetN.name to make a reasonable expression here.
10354 // This only happens if we have argument buffers and we are using OpArrayLength on a lone SSBO in that set.
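// e.g. "(*spvDescriptorSet0.ssbo)" first becomes "spvDescriptorSet0.ssbo", and after
// the '.' -> '_' fixup below yields "spvDescriptorSet0_ssboBufferSize"
// (a sketch, assuming the default "BufferSize" suffix).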
10355 if (expr.size() >= 3 && expr[0] == '(' && expr[1] == '*')
10356 expr = address_of_expression(expr);
10357
10358 // If a buffer is part of an argument buffer translate this to a legal identifier.
10359 for (auto &c : expr)
10360 if (c == '.')
10361 c = '_';
10362
10363 if (index == string::npos)
10364 return expr + buffer_size_name_suffix;
10365 else
10366 {
10367 auto buffer_expr = expr.substr(0, index);
10368 auto array_expr = expr.substr(index);
10369 return buffer_expr + buffer_size_name_suffix + array_expr;
10370 }
10371 }
10372
10373 // Checks whether the type is a Block all of whose members have DecorationPatch.
10374 bool CompilerMSL::is_patch_block(const SPIRType &type)
10375 {
10376 if (!has_decoration(type.self, DecorationBlock))
10377 return false;
10378
10379 for (uint32_t i = 0; i < type.member_types.size(); i++)
10380 {
10381 if (!has_member_decoration(type.self, i, DecorationPatch))
10382 return false;
10383 }
10384
10385 return true;
10386 }
10387
10388 // Checks whether the ID is a row_major matrix that requires conversion before use
10389 bool CompilerMSL::is_non_native_row_major_matrix(uint32_t id)
10390 {
10391 auto *e = maybe_get<SPIRExpression>(id);
10392 if (e)
10393 return e->need_transpose;
10394 else
10395 return has_decoration(id, DecorationRowMajor);
10396 }
10397
10398 // Checks whether the member is a row_major matrix that requires conversion before use
10399 bool CompilerMSL::member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index)
10400 {
10401 return has_member_decoration(type.self, index, DecorationRowMajor);
10402 }
10403
10404 string CompilerMSL::convert_row_major_matrix(string exp_str, const SPIRType &exp_type, uint32_t physical_type_id,
10405 bool is_packed)
10406 {
10407 if (!is_matrix(exp_type))
10408 {
10409 return CompilerGLSL::convert_row_major_matrix(move(exp_str), exp_type, physical_type_id, is_packed);
10410 }
10411 else
10412 {
10413 strip_enclosed_expression(exp_str);
10414 if (physical_type_id != 0 || is_packed)
10415 exp_str = unpack_expression_type(exp_str, exp_type, physical_type_id, is_packed, true);
10416 return join("transpose(", exp_str, ")");
10417 }
10418 }
10419
10420 // Called automatically at the end of the entry point function
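// With both fixup_clipspace and flip_vert_y enabled, the emitted epilogue looks like
// (a sketch; "out.gl_Position" stands in for the qualified position name):
//   out.gl_Position.z = (out.gl_Position.z + out.gl_Position.w) * 0.5;
//   out.gl_Position.y = -(out.gl_Position.y);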
10421 void CompilerMSL::emit_fixup()
10422 {
10423 if (is_vertex_like_shader() && stage_out_var_id && !qual_pos_var_name.empty() && !capture_output_to_buffer)
10424 {
10425 if (options.vertex.fixup_clipspace)
10426 statement(qual_pos_var_name, ".z = (", qual_pos_var_name, ".z + ", qual_pos_var_name,
10427 ".w) * 0.5; // Adjust clip-space for Metal");
10428
10429 if (options.vertex.flip_vert_y)
10430 statement(qual_pos_var_name, ".y = -(", qual_pos_var_name, ".y);", " // Invert Y-axis for Metal");
10431 }
10432 }
10433
10434 // Return a string defining a structure member, with padding and packing.
10435 string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
10436 const string &qualifier)
10437 {
10438 if (member_is_remapped_physical_type(type, index))
10439 member_type_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID);
10440 auto &physical_type = get<SPIRType>(member_type_id);
10441
10442 // If this member is packed, mark it as such.
10443 string pack_pfx;
10444
10445 // Allow Metal to use the array<T> template to make arrays a value type
10446 uint32_t orig_id = 0;
10447 if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID))
10448 orig_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID);
10449
10450 bool row_major = false;
10451 if (is_matrix(physical_type))
10452 row_major = has_member_decoration(type.self, index, DecorationRowMajor);
10453
10454 SPIRType row_major_physical_type;
10455 const SPIRType *declared_type = &physical_type;
10456
10457 // If a struct is being declared with physical layout,
10458 // do not use array<T> wrappers.
10459 // This avoids a lot of complicated cases with packed vectors and matrices,
10460 // and generally we cannot copy full arrays in and out of buffers into Function
10461 // address space.
10462 // Array of resources should also be declared as builtin arrays.
10463 if (has_member_decoration(type.self, index, DecorationOffset))
10464 is_using_builtin_array = true;
10465 else if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary))
10466 is_using_builtin_array = true;
10467
10468 if (member_is_packed_physical_type(type, index))
10469 {
10470 // If we're packing a matrix, output an appropriate typedef
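// e.g. a packed column-major float2x3 member emits
//   typedef packed_float3 packed_float2x3[2];
// while the row-major variant emits
//   typedef packed_float2 packed_rm_float2x3[3];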
10471 if (physical_type.basetype == SPIRType::Struct)
10472 {
10473 SPIRV_CROSS_THROW("Cannot emit a packed struct currently.");
10474 }
10475 else if (is_matrix(physical_type))
10476 {
10477 uint32_t rows = physical_type.vecsize;
10478 uint32_t cols = physical_type.columns;
10479 pack_pfx = "packed_";
10480 if (row_major)
10481 {
10482 // These are stored transposed.
10483 rows = physical_type.columns;
10484 cols = physical_type.vecsize;
10485 pack_pfx = "packed_rm_";
10486 }
10487 string base_type = physical_type.width == 16 ? "half" : "float";
10488 string td_line = "typedef ";
10489 td_line += "packed_" + base_type + to_string(rows);
10490 td_line += " " + pack_pfx;
10491 // Use the actual matrix size here.
10492 td_line += base_type + to_string(physical_type.columns) + "x" + to_string(physical_type.vecsize);
10493 td_line += "[" + to_string(cols) + "]";
10494 td_line += ";";
10495 add_typedef_line(td_line);
10496 }
10497 else if (!is_scalar(physical_type)) // scalar type is already packed.
10498 pack_pfx = "packed_";
10499 }
10500 else if (row_major)
10501 {
10502 // Need to declare type with flipped vecsize/columns.
10503 row_major_physical_type = physical_type;
10504 swap(row_major_physical_type.vecsize, row_major_physical_type.columns);
10505 declared_type = &row_major_physical_type;
10506 }
10507
10508 // Very specifically, image load-store in argument buffers is disallowed in MSL on iOS.
10509 if (msl_options.is_ios() && physical_type.basetype == SPIRType::Image && physical_type.image.sampled == 2)
10510 {
10511 if (!has_decoration(orig_id, DecorationNonWritable))
10512 SPIRV_CROSS_THROW("Writable images are not allowed in argument buffers on iOS.");
10513 }
10514
10515 // Array information is baked into these types.
10516 string array_type;
10517 if (physical_type.basetype != SPIRType::Image && physical_type.basetype != SPIRType::Sampler &&
10518 physical_type.basetype != SPIRType::SampledImage)
10519 {
10520 BuiltIn builtin = BuiltInMax;
10521
10522 // Special handling. In [[stage_out]] or [[stage_in]] blocks,
10523 // we need flat arrays, but if we're somehow declaring gl_PerVertex for constant array reasons, we want
10524 // template array types to be declared.
10525 bool is_ib_in_out =
10526 ((stage_out_var_id && get_stage_out_struct_type().self == type.self &&
10527 variable_storage_requires_stage_io(StorageClassOutput)) ||
10528 (stage_in_var_id && get_stage_in_struct_type().self == type.self &&
10529 variable_storage_requires_stage_io(StorageClassInput)));
10530 if (is_ib_in_out && is_member_builtin(type, index, &builtin))
10531 is_using_builtin_array = true;
10532 array_type = type_to_array_glsl(physical_type);
10533 }
10534
10535 auto result = join(pack_pfx, type_to_glsl(*declared_type, orig_id), " ", qualifier, to_member_name(type, index),
10536 member_attribute_qualifier(type, index), array_type, ";");
10537
10538 is_using_builtin_array = false;
10539 return result;
10540 }
10541
10542 // Emit a structure member, padding and packing to maintain the correct member alignments.
10543 void CompilerMSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
10544 const string &qualifier, uint32_t)
10545 {
10546 // If this member requires padding to maintain its declared offset, emit a dummy padding member before it.
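// e.g. a member at index 2 needing 8 bytes of leading padding emits:
//   char _m2_pad[8];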
10547 if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationPaddingTarget))
10548 {
10549 uint32_t pad_len = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPaddingTarget);
10550 statement("char _m", index, "_pad", "[", pad_len, "];");
10551 }
10552
10553 // Handle HLSL-style 0-based vertex/instance index.
10554 builtin_declaration = true;
10555 statement(to_struct_member(type, member_type_id, index, qualifier));
10556 builtin_declaration = false;
10557 }
10558
10559 void CompilerMSL::emit_struct_padding_target(const SPIRType &type)
10560 {
10561 uint32_t struct_size = get_declared_struct_size_msl(type, true, true);
10562 uint32_t target_size = get_extended_decoration(type.self, SPIRVCrossDecorationPaddingTarget);
10563 if (target_size < struct_size)
10564 SPIRV_CROSS_THROW("Cannot pad with negative bytes.");
10565 else if (target_size > struct_size)
10566 statement("char _m0_final_padding[", target_size - struct_size, "];");
10567 }
10568
10569 // Return a MSL qualifier for the specified function attribute member
10570 string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t index)
10571 {
10572 auto &execution = get_entry_point();
10573
10574 uint32_t mbr_type_id = type.member_types[index];
10575 auto &mbr_type = get<SPIRType>(mbr_type_id);
10576
10577 BuiltIn builtin = BuiltInMax;
10578 bool is_builtin = is_member_builtin(type, index, &builtin);
10579
10580 if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary))
10581 {
10582 string quals = join(
10583 " [[id(", get_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary), ")");
10584 if (interlocked_resources.count(
10585 get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID)))
10586 quals += ", raster_order_group(0)";
10587 quals += "]]";
10588 return quals;
10589 }
10590
10591 // Vertex function inputs
10592 if (execution.model == ExecutionModelVertex && type.storage == StorageClassInput)
10593 {
10594 if (is_builtin)
10595 {
10596 switch (builtin)
10597 {
10598 case BuiltInVertexId:
10599 case BuiltInVertexIndex:
10600 case BuiltInBaseVertex:
10601 case BuiltInInstanceId:
10602 case BuiltInInstanceIndex:
10603 case BuiltInBaseInstance:
10604 if (msl_options.vertex_for_tessellation)
10605 return "";
10606 return string(" [[") + builtin_qualifier(builtin) + "]]";
10607
10608 case BuiltInDrawIndex:
10609 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
10610
10611 default:
10612 return "";
10613 }
10614 }
10615
10616 uint32_t locn;
10617 if (is_builtin)
10618 locn = get_or_allocate_builtin_input_member_location(builtin, type.self, index);
10619 else
10620 locn = get_member_location(type.self, index);
10621
10622 if (locn != k_unknown_location)
10623 return string(" [[attribute(") + convert_to_string(locn) + ")]]";
10624 }
10625
10626 // Vertex and tessellation evaluation function outputs
10627 if (((execution.model == ExecutionModelVertex && !msl_options.vertex_for_tessellation) ||
10628 execution.model == ExecutionModelTessellationEvaluation) &&
10629 type.storage == StorageClassOutput)
10630 {
10631 if (is_builtin)
10632 {
10633 switch (builtin)
10634 {
10635 case BuiltInPointSize:
10636 // Only mark the PointSize builtin if really rendering points.
10637 // Some shaders may include a PointSize builtin even when used to render
10638 // non-point topologies, and Metal will reject this builtin when compiling
10639 // the shader into a render pipeline that uses a non-point topology.
10640 return msl_options.enable_point_size_builtin ? (string(" [[") + builtin_qualifier(builtin) + "]]") : "";
10641
10642 case BuiltInViewportIndex:
10643 if (!msl_options.supports_msl_version(2, 0))
10644 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
10645 /* fallthrough */
10646 case BuiltInPosition:
10647 case BuiltInLayer:
10648 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10649
10650 case BuiltInClipDistance:
10651 if (has_member_decoration(type.self, index, DecorationIndex))
10652 return join(" [[user(clip", get_member_decoration(type.self, index, DecorationIndex), ")]]");
10653 else
10654 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10655
10656 case BuiltInCullDistance:
10657 if (has_member_decoration(type.self, index, DecorationIndex))
10658 return join(" [[user(cull", get_member_decoration(type.self, index, DecorationIndex), ")]]");
10659 else
10660 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10661
10662 default:
10663 return "";
10664 }
10665 }
10666 uint32_t comp;
10667 uint32_t locn = get_member_location(type.self, index, &comp);
10668 if (locn != k_unknown_location)
10669 {
10670 if (comp != k_unknown_component)
10671 return string(" [[user(locn") + convert_to_string(locn) + "_" + convert_to_string(comp) + ")]]";
10672 else
10673 return string(" [[user(locn") + convert_to_string(locn) + ")]]";
10674 }
10675 }
10676
10677 // Tessellation control function inputs
10678 if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassInput)
10679 {
10680 if (is_builtin)
10681 {
10682 switch (builtin)
10683 {
10684 case BuiltInInvocationId:
10685 case BuiltInPrimitiveId:
10686 if (msl_options.multi_patch_workgroup)
10687 return "";
10688 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10689 case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
10690 case BuiltInSubgroupSize: // FIXME: Should work in any stage
10691 if (msl_options.emulate_subgroups)
10692 return "";
10693 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10694 case BuiltInPatchVertices:
10695 return "";
10696 // Others come from stage input.
10697 default:
10698 break;
10699 }
10700 }
10701 if (msl_options.multi_patch_workgroup)
10702 return "";
10703
10704 uint32_t locn;
10705 if (is_builtin)
10706 locn = get_or_allocate_builtin_input_member_location(builtin, type.self, index);
10707 else
10708 locn = get_member_location(type.self, index);
10709
10710 if (locn != k_unknown_location)
10711 return string(" [[attribute(") + convert_to_string(locn) + ")]]";
10712 }
10713
10714 // Tessellation control function outputs
10715 if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassOutput)
10716 {
10717 // For this type of shader, we always arrange for it to capture its
10718 // output to a buffer. For this reason, qualifiers are irrelevant here.
10719 return "";
10720 }
10721
10722 // Tessellation evaluation function inputs
10723 if (execution.model == ExecutionModelTessellationEvaluation && type.storage == StorageClassInput)
10724 {
10725 if (is_builtin)
10726 {
10727 switch (builtin)
10728 {
10729 case BuiltInPrimitiveId:
10730 case BuiltInTessCoord:
10731 return string(" [[") + builtin_qualifier(builtin) + "]]";
10732 case BuiltInPatchVertices:
10733 return "";
10734 // Others come from stage input.
10735 default:
10736 break;
10737 }
10738 }
10739 // The special control point array must not be marked with an attribute.
10740 if (get_type(type.member_types[index]).basetype == SPIRType::ControlPointArray)
10741 return "";
10742
10743 uint32_t locn;
10744 if (is_builtin)
10745 locn = get_or_allocate_builtin_input_member_location(builtin, type.self, index);
10746 else
10747 locn = get_member_location(type.self, index);
10748
10749 if (locn != k_unknown_location)
10750 return string(" [[attribute(") + convert_to_string(locn) + ")]]";
10751 }
10752
10753 // Tessellation evaluation function outputs were handled above.
10754
10755 // Fragment function inputs
10756 if (execution.model == ExecutionModelFragment && type.storage == StorageClassInput)
10757 {
10758 string quals;
10759 if (is_builtin)
10760 {
10761 switch (builtin)
10762 {
10763 case BuiltInViewIndex:
10764 if (!msl_options.multiview || !msl_options.multiview_layered_rendering)
10765 break;
10766 /* fallthrough */
10767 case BuiltInFrontFacing:
10768 case BuiltInPointCoord:
10769 case BuiltInFragCoord:
10770 case BuiltInSampleId:
10771 case BuiltInSampleMask:
10772 case BuiltInLayer:
10773 case BuiltInBaryCoordNV:
10774 case BuiltInBaryCoordNoPerspNV:
10775 quals = builtin_qualifier(builtin);
10776 break;
10777
10778 case BuiltInClipDistance:
10779 return join(" [[user(clip", get_member_decoration(type.self, index, DecorationIndex), ")]]");
10780 case BuiltInCullDistance:
10781 return join(" [[user(cull", get_member_decoration(type.self, index, DecorationIndex), ")]]");
10782
10783 default:
10784 break;
10785 }
10786 }
10787 else
10788 {
10789 uint32_t comp;
10790 uint32_t locn = get_member_location(type.self, index, &comp);
10791 if (locn != k_unknown_location)
10792 {
10793 // For user-defined attributes, this is fine. From Vulkan spec:
10794 // A user-defined output variable is considered to match an input variable in the subsequent stage if
10795 // the two variables are declared with the same Location and Component decoration and match in type
10796 // and decoration, except that interpolation decorations are not required to match. For the purposes
10797 // of interface matching, variables declared without a Component decoration are considered to have a
10798 // Component decoration of zero.
10799
10800 if (comp != k_unknown_component && comp != 0)
10801 quals = string("user(locn") + convert_to_string(locn) + "_" + convert_to_string(comp) + ")";
10802 else
10803 quals = string("user(locn") + convert_to_string(locn) + ")";
10804 }
10805 }
10806
10807 if (builtin == BuiltInBaryCoordNV || builtin == BuiltInBaryCoordNoPerspNV)
10808 {
10809 if (has_member_decoration(type.self, index, DecorationFlat) ||
10810 has_member_decoration(type.self, index, DecorationCentroid) ||
10811 has_member_decoration(type.self, index, DecorationSample) ||
10812 has_member_decoration(type.self, index, DecorationNoPerspective))
10813 {
10814 // NoPerspective is baked into the builtin type.
10815 SPIRV_CROSS_THROW(
10816 "Flat, Centroid, Sample, NoPerspective decorations are not supported for BaryCoord inputs.");
10817 }
10818 }
10819
10820 // Don't bother decorating integers with the 'flat' attribute; it's
10821 // the default (in fact, the only option). Also don't bother with the
10822 // FragCoord builtin; it's always noperspective on Metal.
10823 if (!type_is_integral(mbr_type) && (!is_builtin || builtin != BuiltInFragCoord))
10824 {
10825 if (has_member_decoration(type.self, index, DecorationFlat))
10826 {
10827 if (!quals.empty())
10828 quals += ", ";
10829 quals += "flat";
10830 }
10831 else if (has_member_decoration(type.self, index, DecorationCentroid))
10832 {
10833 if (!quals.empty())
10834 quals += ", ";
10835 if (has_member_decoration(type.self, index, DecorationNoPerspective))
10836 quals += "centroid_no_perspective";
10837 else
10838 quals += "centroid_perspective";
10839 }
10840 else if (has_member_decoration(type.self, index, DecorationSample))
10841 {
10842 if (!quals.empty())
10843 quals += ", ";
10844 if (has_member_decoration(type.self, index, DecorationNoPerspective))
10845 quals += "sample_no_perspective";
10846 else
10847 quals += "sample_perspective";
10848 }
10849 else if (has_member_decoration(type.self, index, DecorationNoPerspective))
10850 {
10851 if (!quals.empty())
10852 quals += ", ";
10853 quals += "center_no_perspective";
10854 }
10855 }
10856
10857 if (!quals.empty())
10858 return " [[" + quals + "]]";
10859 }
10860
10861 // Fragment function outputs
10862 if (execution.model == ExecutionModelFragment && type.storage == StorageClassOutput)
10863 {
10864 if (is_builtin)
10865 {
10866 switch (builtin)
10867 {
10868 case BuiltInFragStencilRefEXT:
10869 // Similar to PointSize, only mark FragStencilRef if there's a stencil buffer.
10870 // Some shaders may include a FragStencilRef builtin even when used to render
10871 // without a stencil attachment, and Metal will reject this builtin
10872 // when compiling the shader into a render pipeline that does not set
10873 // stencilAttachmentPixelFormat.
10874 if (!msl_options.enable_frag_stencil_ref_builtin)
10875 return "";
10876 if (!msl_options.supports_msl_version(2, 1))
10877 SPIRV_CROSS_THROW("Stencil export only supported in MSL 2.1 and up.");
10878 return string(" [[") + builtin_qualifier(builtin) + "]]";
10879
10880 case BuiltInFragDepth:
10881 // Ditto FragDepth.
10882 if (!msl_options.enable_frag_depth_builtin)
10883 return "";
10884 /* fallthrough */
10885 case BuiltInSampleMask:
10886 return string(" [[") + builtin_qualifier(builtin) + "]]";
10887
10888 default:
10889 return "";
10890 }
10891 }
10892 uint32_t locn = get_member_location(type.self, index);
10893 // Metal will likely complain about missing color attachments, too.
10894 if (locn != k_unknown_location && !(msl_options.enable_frag_output_mask & (1 << locn)))
10895 return "";
10896 if (locn != k_unknown_location && has_member_decoration(type.self, index, DecorationIndex))
10897 return join(" [[color(", locn, "), index(", get_member_decoration(type.self, index, DecorationIndex),
10898 ")]]");
10899 else if (locn != k_unknown_location)
10900 return join(" [[color(", locn, ")]]");
10901 else if (has_member_decoration(type.self, index, DecorationIndex))
10902 return join(" [[index(", get_member_decoration(type.self, index, DecorationIndex), ")]]");
10903 else
10904 return "";
10905 }
10906
10907 // Compute function inputs
10908 if (execution.model == ExecutionModelGLCompute && type.storage == StorageClassInput)
10909 {
10910 if (is_builtin)
10911 {
10912 switch (builtin)
10913 {
10914 case BuiltInNumSubgroups:
10915 case BuiltInSubgroupId:
10916 case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
10917 case BuiltInSubgroupSize: // FIXME: Should work in any stage
10918 if (msl_options.emulate_subgroups)
10919 break;
10920 /* fallthrough */
10921 case BuiltInGlobalInvocationId:
10922 case BuiltInWorkgroupId:
10923 case BuiltInNumWorkgroups:
10924 case BuiltInLocalInvocationId:
10925 case BuiltInLocalInvocationIndex:
10926 return string(" [[") + builtin_qualifier(builtin) + "]]";
10927
10928 default:
10929 return "";
10930 }
10931 }
10932 }
10933
10934 return "";
10935 }
10936
10937 // Returns the location decoration of the member with the specified index in the specified type.
10938 // If the location of the member has been explicitly set, that location is returned; otherwise
10939 // k_unknown_location is returned. If comp is non-null, it receives the member's Component
10940 // decoration, or k_unknown_component if no Component decoration is present.
10941 uint32_t CompilerMSL::get_member_location(uint32_t type_id, uint32_t index, uint32_t *comp) const
10942 {
10943 if (comp)
10944 {
10945 if (has_member_decoration(type_id, index, DecorationComponent))
10946 *comp = get_member_decoration(type_id, index, DecorationComponent);
10947 else
10948 *comp = k_unknown_component;
10949 }
10950
10951 if (has_member_decoration(type_id, index, DecorationLocation))
10952 return get_member_decoration(type_id, index, DecorationLocation);
10953 else
10954 return k_unknown_location;
10955 }
10956
10957 uint32_t CompilerMSL::get_or_allocate_builtin_input_member_location(spv::BuiltIn builtin,
10958 uint32_t type_id, uint32_t index,
10959 uint32_t *comp)
10960 {
10961 uint32_t loc = get_member_location(type_id, index, comp);
10962 if (loc != k_unknown_location)
10963 return loc;
10964
10965 if (comp)
10966 *comp = k_unknown_component;
10967
10968 // Late allocation. Find a location which is unused by the application.
10969 // This can happen for built-in inputs in tessellation which are mixed and matched with user inputs.
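// e.g. if user inputs already occupy locations 0-2 and this builtin needs one
// location, the loop below settles on location 3.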
10970 auto &mbr_type = get<SPIRType>(get<SPIRType>(type_id).member_types[index]);
10971 uint32_t count = type_to_location_count(mbr_type);
10972
10973 loc = 0;
10974
10975 const auto location_range_in_use = [this](uint32_t location, uint32_t location_count) -> bool {
10976 for (uint32_t i = 0; i < location_count; i++)
10977 if (location_inputs_in_use.count(location + i) != 0)
10978 return true;
10979 return false;
10980 };
10981
10982 while (location_range_in_use(loc, count))
10983 loc++;
10984
10985 set_member_decoration(type_id, index, DecorationLocation, loc);
10986
10987 // Triangle tess level inputs are shared in one packed float4,
10988 // mark both builtins as sharing one location.
10989 if (get_execution_mode_bitset().get(ExecutionModeTriangles) &&
10990 (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter))
10991 {
10992 builtin_to_automatic_input_location[BuiltInTessLevelInner] = loc;
10993 builtin_to_automatic_input_location[BuiltInTessLevelOuter] = loc;
10994 }
10995 else
10996 builtin_to_automatic_input_location[builtin] = loc;
10997
10998 mark_location_as_used_by_shader(loc, mbr_type, StorageClassInput, true);
10999 return loc;
11000 }
11001
11002 // Returns the type declaration for a function, including the
11003 // entry type if the current function is the entry point function
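// e.g. a macOS tessellation evaluation function for triangle patches with 3
// output vertices is declared as:
//   [[ patch(triangle, 3) ]] vertex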
11004 string CompilerMSL::func_type_decl(SPIRType &type)
11005 {
11006 // The regular function return type. If not processing the entry point function, that's all we need
11007 string return_type = type_to_glsl(type) + type_to_array_glsl(type);
11008 if (!processing_entry_point)
11009 return return_type;
11010
11011 // If an outgoing interface block has been defined, and it should be returned, override the entry point return type
11012 bool ep_should_return_output = !get_is_rasterization_disabled();
11013 if (stage_out_var_id && ep_should_return_output)
11014 return_type = type_to_glsl(get_stage_out_struct_type()) + type_to_array_glsl(type);
11015
11016 // Prepend an entry type, based on the execution model
11017 string entry_type;
11018 auto &execution = get_entry_point();
11019 switch (execution.model)
11020 {
11021 case ExecutionModelVertex:
11022 if (msl_options.vertex_for_tessellation && !msl_options.supports_msl_version(1, 2))
11023 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
11024 entry_type = msl_options.vertex_for_tessellation ? "kernel" : "vertex";
11025 break;
11026 case ExecutionModelTessellationEvaluation:
11027 if (!msl_options.supports_msl_version(1, 2))
11028 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
11029 if (execution.flags.get(ExecutionModeIsolines))
11030 SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
11031 if (msl_options.is_ios())
11032 entry_type =
11033 join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ") ]] vertex");
11034 else
11035 entry_type = join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ", ",
11036 execution.output_vertices, ") ]] vertex");
11037 break;
11038 case ExecutionModelFragment:
11039 entry_type = execution.flags.get(ExecutionModeEarlyFragmentTests) ||
11040 execution.flags.get(ExecutionModePostDepthCoverage) ?
11041 "[[ early_fragment_tests ]] fragment" :
11042 "fragment";
11043 break;
11044 case ExecutionModelTessellationControl:
11045 if (!msl_options.supports_msl_version(1, 2))
11046 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
11047 if (execution.flags.get(ExecutionModeIsolines))
11048 SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
11049 /* fallthrough */
11050 case ExecutionModelGLCompute:
11051 case ExecutionModelKernel:
11052 entry_type = "kernel";
11053 break;
11054 default:
11055 entry_type = "unknown";
11056 break;
11057 }
11058
11059 return entry_type + " " + return_type;
11060 }
11061
11062 // In MSL, address space qualifiers are required for all pointer or reference variables
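// e.g. an SSBO argument is declared as "device T& v" (or "const device T& v" when
// marked NonWritable), while a UBO argument becomes "constant T& v".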
11063 string CompilerMSL::get_argument_address_space(const SPIRVariable &argument)
11064 {
11065 const auto &type = get<SPIRType>(argument.basetype);
11066 return get_type_address_space(type, argument.self, true);
11067 }
11068
11069 string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bool argument)
11070 {
11071 // This can be called for variable pointer contexts as well, so be very careful about which method we choose.
11072 Bitset flags;
11073 auto *var = maybe_get<SPIRVariable>(id);
11074 if (var && type.basetype == SPIRType::Struct &&
11075 (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
11076 flags = get_buffer_block_flags(id);
11077 else
11078 flags = get_decoration_bitset(id);
11079
11080 const char *addr_space = nullptr;
11081 switch (type.storage)
11082 {
11083 case StorageClassWorkgroup:
11084 addr_space = "threadgroup";
11085 break;
11086
11087 case StorageClassStorageBuffer:
11088 {
11089 // For arguments from variable pointers, we use the write count deduction, so
11090 // we should not assume any constness here. Only for global SSBOs.
11091 bool readonly = false;
11092 if (!var || has_decoration(type.self, DecorationBlock))
11093 readonly = flags.get(DecorationNonWritable);
11094
11095 addr_space = readonly ? "const device" : "device";
11096 break;
11097 }
11098
11099 case StorageClassUniform:
11100 case StorageClassUniformConstant:
11101 case StorageClassPushConstant:
11102 if (type.basetype == SPIRType::Struct)
11103 {
11104 bool ssbo = has_decoration(type.self, DecorationBufferBlock);
11105 if (ssbo)
11106 addr_space = flags.get(DecorationNonWritable) ? "const device" : "device";
11107 else
11108 addr_space = "constant";
11109 }
11110 else if (!argument)
11111 {
11112 addr_space = "constant";
11113 }
11114 else if (type_is_msl_framebuffer_fetch(type))
11115 {
11116 // Subpass inputs are passed around by value.
11117 addr_space = "";
11118 }
11119 break;
11120
11121 case StorageClassFunction:
11122 case StorageClassGeneric:
11123 break;
11124
11125 case StorageClassInput:
11126 if (get_execution_model() == ExecutionModelTessellationControl && var &&
11127 var->basevariable == stage_in_ptr_var_id)
11128 addr_space = msl_options.multi_patch_workgroup ? "constant" : "threadgroup";
11129 if (get_execution_model() == ExecutionModelFragment && var && var->basevariable == stage_in_var_id)
11130 addr_space = "thread";
11131 break;
11132
11133 case StorageClassOutput:
11134 if (capture_output_to_buffer)
11135 {
11136 if (var && type.storage == StorageClassOutput)
11137 {
11138 bool is_masked = is_stage_output_variable_masked(*var);
11139
11140 if (is_masked)
11141 {
11142 if (is_tessellation_shader())
11143 addr_space = "threadgroup";
11144 else
11145 addr_space = "thread";
11146 }
11147 else if (variable_decl_is_remapped_storage(*var, StorageClassWorkgroup))
11148 addr_space = "threadgroup";
11149 }
11150
11151 if (!addr_space)
11152 addr_space = "device";
11153 }
11154 break;
11155
11156 default:
11157 break;
11158 }
11159
11160 if (!addr_space)
11161 {
11162 // No address space for plain values.
11163 addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : "";
11164 }
11165
11166 return join(flags.get(DecorationVolatile) || flags.get(DecorationCoherent) ? "volatile " : "", addr_space);
11167 }
11168
11169 const char *CompilerMSL::to_restrict(uint32_t id, bool space)
11170 {
11171 // This can be called for variable pointer contexts as well, so be very careful about which method we choose.
11172 Bitset flags;
11173 if (ir.ids[id].get_type() == TypeVariable)
11174 {
11175 uint32_t type_id = expression_type_id(id);
11176 auto &type = expression_type(id);
11177 if (type.basetype == SPIRType::Struct &&
11178 (has_decoration(type_id, DecorationBlock) || has_decoration(type_id, DecorationBufferBlock)))
11179 flags = get_buffer_block_flags(id);
11180 else
11181 flags = get_decoration_bitset(id);
11182 }
11183 else
11184 flags = get_decoration_bitset(id);
11185
11186 return flags.get(DecorationRestrict) ? (space ? "restrict " : "restrict") : "";
11187 }
11188
11189 string CompilerMSL::entry_point_arg_stage_in()
11190 {
11191 string decl;
11192
11193 if (get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup)
11194 return decl;
11195
11196 // Stage-in structure
11197 uint32_t stage_in_id;
11198 if (get_execution_model() == ExecutionModelTessellationEvaluation)
11199 stage_in_id = patch_stage_in_var_id;
11200 else
11201 stage_in_id = stage_in_var_id;
11202
11203 if (stage_in_id)
11204 {
11205 auto &var = get<SPIRVariable>(stage_in_id);
11206 auto &type = get_variable_data_type(var);
11207
11208 add_resource_name(var.self);
11209 decl = join(type_to_glsl(type), " ", to_name(var.self), " [[stage_in]]");
11210 }
11211
11212 return decl;
11213 }
11214
11215 // Returns true if this input builtin should be a direct parameter on a shader function parameter list,
11216 // and false for builtins that should be passed or calculated some other way.
11217 bool CompilerMSL::is_direct_input_builtin(BuiltIn bi_type)
11218 {
11219 switch (bi_type)
11220 {
11221 // Vertex function in
11222 case BuiltInVertexId:
11223 case BuiltInVertexIndex:
11224 case BuiltInBaseVertex:
11225 case BuiltInInstanceId:
11226 case BuiltInInstanceIndex:
11227 case BuiltInBaseInstance:
11228 return get_execution_model() != ExecutionModelVertex || !msl_options.vertex_for_tessellation;
11229 // Tess. control function in
11230 case BuiltInPosition:
11231 case BuiltInPointSize:
11232 case BuiltInClipDistance:
11233 case BuiltInCullDistance:
11234 case BuiltInPatchVertices:
11235 return false;
11236 case BuiltInInvocationId:
11237 case BuiltInPrimitiveId:
11238 return get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup;
11239 // Tess. evaluation function in
11240 case BuiltInTessLevelInner:
11241 case BuiltInTessLevelOuter:
11242 return false;
11243 // Fragment function in
11244 case BuiltInSamplePosition:
11245 case BuiltInHelperInvocation:
11246 case BuiltInBaryCoordNV:
11247 case BuiltInBaryCoordNoPerspNV:
11248 return false;
11249 case BuiltInViewIndex:
11250 return get_execution_model() == ExecutionModelFragment && msl_options.multiview &&
11251 msl_options.multiview_layered_rendering;
11252 // Compute function in
11253 case BuiltInSubgroupId:
11254 case BuiltInNumSubgroups:
11255 return !msl_options.emulate_subgroups;
11256 // Any stage function in
11257 case BuiltInDeviceIndex:
11258 case BuiltInSubgroupEqMask:
11259 case BuiltInSubgroupGeMask:
11260 case BuiltInSubgroupGtMask:
11261 case BuiltInSubgroupLeMask:
11262 case BuiltInSubgroupLtMask:
11263 return false;
11264 case BuiltInSubgroupSize:
11265 if (msl_options.fixed_subgroup_size != 0)
11266 return false;
11267 /* fallthrough */
11268 case BuiltInSubgroupLocalInvocationId:
11269 return !msl_options.emulate_subgroups;
11270 default:
11271 return true;
11272 }
11273 }
11274
11275 // Returns true if this is a fragment shader that runs per sample, and false otherwise.
11276 bool CompilerMSL::is_sample_rate() const
11277 {
11278 auto &caps = get_declared_capabilities();
11279 return get_execution_model() == ExecutionModelFragment &&
11280 (msl_options.force_sample_rate_shading ||
11281 std::find(caps.begin(), caps.end(), CapabilitySampleRateShading) != caps.end() ||
11282 (msl_options.use_framebuffer_fetch_subpasses && need_subpass_input));
11283 }
11284
11285 void CompilerMSL::entry_point_args_builtin(string &ep_args)
11286 {
11287 // Builtin variables
11288 SmallVector<pair<SPIRVariable *, BuiltIn>, 8> active_builtins;
11289 ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
11290 if (var.storage != StorageClassInput)
11291 return;
11292
11293 auto bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
11294
11295 // Don't emit SamplePosition as a separate parameter. In the entry
11296 // point, we get that by calling get_sample_position() on the sample ID.
11297 if (is_builtin_variable(var) &&
11298 get_variable_data_type(var).basetype != SPIRType::Struct &&
11299 get_variable_data_type(var).basetype != SPIRType::ControlPointArray)
11300 {
11301 // If the builtin is not part of the active input builtin set, don't emit it.
11302 // Relevant for multiple entry-point modules which might declare unused builtins.
11303 if (!active_input_builtins.get(bi_type) || !interface_variable_exists_in_entry_point(var_id))
11304 return;
11305
11306 // Remember this variable. We may need to correct its type.
11307 active_builtins.push_back(make_pair(&var, bi_type));
11308
11309 if (is_direct_input_builtin(bi_type))
11310 {
11311 if (!ep_args.empty())
11312 ep_args += ", ";
11313
11314 // Handle HLSL-style 0-based vertex/instance index.
11315 builtin_declaration = true;
11316 ep_args += builtin_type_decl(bi_type, var_id) + " " + to_expression(var_id);
11317 ep_args += " [[" + builtin_qualifier(bi_type);
11318 if (bi_type == BuiltInSampleMask && get_entry_point().flags.get(ExecutionModePostDepthCoverage))
11319 {
11320 if (!msl_options.supports_msl_version(2))
11321 SPIRV_CROSS_THROW("Post-depth coverage requires MSL 2.0.");
11322 if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 3))
11323 SPIRV_CROSS_THROW("Post-depth coverage on Mac requires MSL 2.3.");
11324 ep_args += ", post_depth_coverage";
11325 }
11326 ep_args += "]]";
11327 builtin_declaration = false;
11328 }
11329 }
11330
11331 if (has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase))
11332 {
11333 // This is a special implicit builtin, not corresponding to any SPIR-V builtin,
11334 // which holds the base that was passed to vkCmdDispatchBase() or vkCmdDrawIndexed(). If it's present,
11335 // assume we emitted it for a good reason.
11336 assert(msl_options.supports_msl_version(1, 2));
11337 if (!ep_args.empty())
11338 ep_args += ", ";
11339
11340 ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_origin]]";
11341 }
11342
11343 if (has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize))
11344 {
11345 // This is another special implicit builtin, not corresponding to any SPIR-V builtin,
11346 // which holds the number of vertices and instances to draw. If it's present,
11347 // assume we emitted it for a good reason.
11348 assert(msl_options.supports_msl_version(1, 2));
11349 if (!ep_args.empty())
11350 ep_args += ", ";
11351
11352 ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_size]]";
11353 }
11354 });
11355
11356 // Correct the types of all encountered active builtins. We couldn't do this before
11357 // because ensure_correct_builtin_type() may increase the bound, which isn't allowed
11358 // while iterating over IDs.
11359 for (auto &var : active_builtins)
11360 var.first->basetype = ensure_correct_builtin_type(var.first->basetype, var.second);
11361
11362 // Handle HLSL-style 0-based vertex/instance index.
11363 if (needs_base_vertex_arg == TriState::Yes)
11364 ep_args += built_in_func_arg(BuiltInBaseVertex, !ep_args.empty());
11365
11366 if (needs_base_instance_arg == TriState::Yes)
11367 ep_args += built_in_func_arg(BuiltInBaseInstance, !ep_args.empty());
11368
11369 if (capture_output_to_buffer)
11370 {
11371 // Add parameters to hold the indirect draw parameters and the shader output. This has to be handled
11372 // specially because it needs to be a pointer, not a reference.
11373 if (stage_out_var_id)
11374 {
11375 if (!ep_args.empty())
11376 ep_args += ", ";
11377 ep_args += join("device ", type_to_glsl(get_stage_out_struct_type()), "* ", output_buffer_var_name,
11378 " [[buffer(", msl_options.shader_output_buffer_index, ")]]");
11379 }
11380
11381 if (get_execution_model() == ExecutionModelTessellationControl)
11382 {
11383 if (!ep_args.empty())
11384 ep_args += ", ";
11385 ep_args +=
11386 join("constant uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
11387 }
11388 else if (stage_out_var_id &&
11389 !(get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
11390 {
11391 if (!ep_args.empty())
11392 ep_args += ", ";
11393 ep_args +=
11394 join("device uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
11395 }
11396
11397 if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation &&
11398 (active_input_builtins.get(BuiltInVertexIndex) || active_input_builtins.get(BuiltInVertexId)) &&
11399 msl_options.vertex_index_type != Options::IndexType::None)
11400 {
11401 // Add the index buffer so we can set gl_VertexIndex correctly.
11402 if (!ep_args.empty())
11403 ep_args += ", ";
11404 switch (msl_options.vertex_index_type)
11405 {
11406 case Options::IndexType::None:
11407 break;
11408 case Options::IndexType::UInt16:
11409 ep_args += join("const device ushort* ", index_buffer_var_name, " [[buffer(",
11410 msl_options.shader_index_buffer_index, ")]]");
11411 break;
11412 case Options::IndexType::UInt32:
11413 ep_args += join("const device uint* ", index_buffer_var_name, " [[buffer(",
11414 msl_options.shader_index_buffer_index, ")]]");
11415 break;
11416 }
11417 }
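// For example, with IndexType::UInt16 this contributes an argument of the form
//   const device ushort* spvIndices [[buffer(n)]]
// where "spvIndices" stands in for index_buffer_var_name and n for
// msl_options.shader_index_buffer_index; both depend on configuration.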
11418
11419 // Tessellation control shaders get three additional parameters:
11420 // a buffer to hold the per-patch data, a buffer to hold the per-patch
11421 // tessellation levels, and a block of workgroup memory to hold the
11422 // input control point data.
11423 if (get_execution_model() == ExecutionModelTessellationControl)
11424 {
11425 if (patch_stage_out_var_id)
11426 {
11427 if (!ep_args.empty())
11428 ep_args += ", ";
11429 ep_args +=
11430 join("device ", type_to_glsl(get_patch_stage_out_struct_type()), "* ", patch_output_buffer_var_name,
11431 " [[buffer(", convert_to_string(msl_options.shader_patch_output_buffer_index), ")]]");
11432 }
11433 if (!ep_args.empty())
11434 ep_args += ", ";
11435 ep_args += join("device ", get_tess_factor_struct_name(), "* ", tess_factor_buffer_var_name, " [[buffer(",
11436 convert_to_string(msl_options.shader_tess_factor_buffer_index), ")]]");
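// Illustratively, the buffers emitted above look roughly like
//   device TPatchOut* spvPatchOut [[buffer(p)]], device TTessFactor* spvTessLevel [[buffer(t)]]
// where the names, struct types, and indices are placeholders for the
// configured variable names, get_tess_factor_struct_name(), and msl_options indices.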
11437
11438 // Initializer for tess factors must be handled specially since it's never declared as a normal variable.
11439 uint32_t outer_factor_initializer_id = 0;
11440 uint32_t inner_factor_initializer_id = 0;
11441 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
11442 if (!has_decoration(var.self, DecorationBuiltIn) || var.storage != StorageClassOutput || !var.initializer)
11443 return;
11444
11445 BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
11446 if (builtin == BuiltInTessLevelInner)
11447 inner_factor_initializer_id = var.initializer;
11448 else if (builtin == BuiltInTessLevelOuter)
11449 outer_factor_initializer_id = var.initializer;
11450 });
11451
11452 const SPIRConstant *c = nullptr;
11453
11454 if (outer_factor_initializer_id && (c = maybe_get<SPIRConstant>(outer_factor_initializer_id)))
11455 {
11456 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
11457 entry_func.fixup_hooks_in.push_back([=]() {
11458 uint32_t components = get_execution_mode_bitset().get(ExecutionModeTriangles) ? 3 : 4;
11459 for (uint32_t i = 0; i < components; i++)
11460 {
11461 statement(builtin_to_glsl(BuiltInTessLevelOuter, StorageClassOutput), "[", i, "] = ",
11462 "half(", to_expression(c->subconstants[i]), ");");
11463 }
11464 });
11465 }
11466
11467 if (inner_factor_initializer_id && (c = maybe_get<SPIRConstant>(inner_factor_initializer_id)))
11468 {
11469 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
11470 if (get_execution_mode_bitset().get(ExecutionModeTriangles))
11471 {
11472 entry_func.fixup_hooks_in.push_back([=]() {
11473 statement(builtin_to_glsl(BuiltInTessLevelInner, StorageClassOutput), " = ", "half(",
11474 to_expression(c->subconstants[0]), ");");
11475 });
11476 }
11477 else
11478 {
11479 entry_func.fixup_hooks_in.push_back([=]() {
11480 for (uint32_t i = 0; i < 2; i++)
11481 {
11482 statement(builtin_to_glsl(BuiltInTessLevelInner, StorageClassOutput), "[", i, "] = ",
11483 "half(", to_expression(c->subconstants[i]), ");");
11484 }
11485 });
11486 }
11487 }
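// Conceptually, the hooks above emit plain assignments at the top of the entry
// point, e.g. (values hypothetical, taken from the SPIR-V constant initializer):
//   gl_TessLevelOuter[0] = half(1.0);
//   gl_TessLevelInner = half(1.0);   // triangles mode: inner factor is scalar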
11488
11489 if (stage_in_var_id)
11490 {
11491 if (!ep_args.empty())
11492 ep_args += ", ";
11493 if (msl_options.multi_patch_workgroup)
11494 {
11495 ep_args += join("device ", type_to_glsl(get_stage_in_struct_type()), "* ", input_buffer_var_name,
11496 " [[buffer(", convert_to_string(msl_options.shader_input_buffer_index), ")]]");
11497 }
11498 else
11499 {
11500 ep_args += join("threadgroup ", type_to_glsl(get_stage_in_struct_type()), "* ", input_wg_var_name,
11501 " [[threadgroup(", convert_to_string(msl_options.shader_input_wg_index), ")]]");
11502 }
11503 }
11504 }
11505 }
11506 }
11507
11508 string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
11509 {
11510 string ep_args = entry_point_arg_stage_in();
11511 Bitset claimed_bindings;
11512
11513 for (uint32_t i = 0; i < kMaxArgumentBuffers; i++)
11514 {
11515 uint32_t id = argument_buffer_ids[i];
11516 if (id == 0)
11517 continue;
11518
11519 add_resource_name(id);
11520 auto &var = get<SPIRVariable>(id);
11521 auto &type = get_variable_data_type(var);
11522
11523 if (!ep_args.empty())
11524 ep_args += ", ";
11525
11526 // Check if the argument buffer binding itself has been remapped.
11527 uint32_t buffer_binding;
11528 auto itr = resource_bindings.find({ get_entry_point().model, i, kArgumentBufferBinding });
11529 if (itr != end(resource_bindings))
11530 {
11531 buffer_binding = itr->second.first.msl_buffer;
11532 itr->second.second = true;
11533 }
11534 else
11535 {
11536 // As a fallback, directly map desc set <-> binding.
11537 // If that was taken, take the next buffer binding.
11538 if (claimed_bindings.get(i))
11539 buffer_binding = next_metal_resource_index_buffer;
11540 else
11541 buffer_binding = i;
11542 }
11543
11544 claimed_bindings.set(buffer_binding);
11545
11546 ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id) + to_name(id);
11547 ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]";
11548
11549 next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1);
11550 }
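// Each active argument buffer contributes one parameter of the form
//   constant spvDescriptorSetBuffer0& spvDescriptorSet0 [[buffer(0)]]
// (names illustrative; the address space comes from get_argument_address_space(),
// and the [[buffer(n)]] index from the remap/claim logic above).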
11551
11552 entry_point_args_discrete_descriptors(ep_args);
11553 entry_point_args_builtin(ep_args);
11554
11555 if (!ep_args.empty() && append_comma)
11556 ep_args += ", ";
11557
11558 return ep_args;
11559 }
11560
11561 const MSLConstexprSampler *CompilerMSL::find_constexpr_sampler(uint32_t id) const
11562 {
11563 // Try by ID.
11564 {
11565 auto itr = constexpr_samplers_by_id.find(id);
11566 if (itr != end(constexpr_samplers_by_id))
11567 return &itr->second;
11568 }
11569
11570 // Try by binding.
11571 {
11572 uint32_t desc_set = get_decoration(id, DecorationDescriptorSet);
11573 uint32_t binding = get_decoration(id, DecorationBinding);
11574
11575 auto itr = constexpr_samplers_by_binding.find({ desc_set, binding });
11576 if (itr != end(constexpr_samplers_by_binding))
11577 return &itr->second;
11578 }
11579
11580 return nullptr;
11581 }
11582
11583 void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
11584 {
11585 // Output resources, sorted by resource index & type
11586 // We need to sort to work around a bug on macOS 10.13 with NVidia drivers where switching between shaders
11587 // with different order of buffers can result in issues with buffer assignments inside the driver.
11588 struct Resource
11589 {
11590 SPIRVariable *var;
11591 string name;
11592 SPIRType::BaseType basetype;
11593 uint32_t index;
11594 uint32_t plane;
11595 uint32_t secondary_index;
11596 };
11597
11598 SmallVector<Resource> resources;
11599
11600 ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
11601 if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
11602 var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer) &&
11603 !is_hidden_variable(var))
11604 {
11605 auto &type = get_variable_data_type(var);
11606
11607 if (is_supported_argument_buffer_type(type) && var.storage != StorageClassPushConstant)
11608 {
11609 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
11610 if (descriptor_set_is_argument_buffer(desc_set))
11611 return;
11612 }
11613
11614 const MSLConstexprSampler *constexpr_sampler = nullptr;
11615 if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
11616 {
11617 constexpr_sampler = find_constexpr_sampler(var_id);
11618 if (constexpr_sampler)
11619 {
11620 // Mark this ID as a constexpr sampler for later in case it came from set/bindings.
11621 constexpr_samplers_by_id[var_id] = *constexpr_sampler;
11622 }
11623 }
11624
11625 // Emulate texture2D atomic operations
11626 uint32_t secondary_index = 0;
11627 if (atomic_image_vars.count(var.self))
11628 {
11629 secondary_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0);
11630 }
11631
11632 if (type.basetype == SPIRType::SampledImage)
11633 {
11634 add_resource_name(var_id);
11635
11636 uint32_t plane_count = 1;
11637 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
11638 plane_count = constexpr_sampler->planes;
11639
11640 for (uint32_t i = 0; i < plane_count; i++)
11641 resources.push_back({ &var, to_name(var_id), SPIRType::Image,
11642 get_metal_resource_index(var, SPIRType::Image, i), i, secondary_index });
11643
11644 if (type.image.dim != DimBuffer && !constexpr_sampler)
11645 {
11646 resources.push_back({ &var, to_sampler_expression(var_id), SPIRType::Sampler,
11647 get_metal_resource_index(var, SPIRType::Sampler), 0, 0 });
11648 }
11649 }
11650 else if (!constexpr_sampler)
11651 {
11652 // constexpr samplers are not declared as resources.
11653 add_resource_name(var_id);
11654 resources.push_back({ &var, to_name(var_id), type.basetype,
11655 get_metal_resource_index(var, type.basetype), 0, secondary_index });
11656 }
11657 }
11658 });
11659
11660 sort(resources.begin(), resources.end(), [](const Resource &lhs, const Resource &rhs) {
11661 return tie(lhs.basetype, lhs.index) < tie(rhs.basetype, rhs.index);
11662 });
11663
11664 for (auto &r : resources)
11665 {
11666 auto &var = *r.var;
11667 auto &type = get_variable_data_type(var);
11668
11669 uint32_t var_id = var.self;
11670
11671 switch (r.basetype)
11672 {
11673 case SPIRType::Struct:
11674 {
11675 auto &m = ir.meta[type.self];
11676 if (m.members.size() == 0)
11677 break;
11678 if (!type.array.empty())
11679 {
11680 if (type.array.size() > 1)
11681 SPIRV_CROSS_THROW("Arrays of arrays of buffers are not supported.");
11682
11683 // Metal doesn't directly support this, so we must expand the
11684 // array. We'll declare a local array to hold these elements
11685 // later.
11686 uint32_t array_size = to_array_size_literal(type);
11687
11688 if (array_size == 0)
11689 SPIRV_CROSS_THROW("Unsized arrays of buffers are not supported in MSL.");
11690
11691 // Allow Metal to use the array<T> template to make arrays a value type
11692 is_using_builtin_array = true;
11693 buffer_arrays.push_back(var_id);
11694 for (uint32_t i = 0; i < array_size; ++i)
11695 {
11696 if (!ep_args.empty())
11697 ep_args += ", ";
11698 ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "* " + to_restrict(var_id) +
11699 r.name + "_" + convert_to_string(i);
11700 ep_args += " [[buffer(" + convert_to_string(r.index + i) + ")";
11701 if (interlocked_resources.count(var_id))
11702 ep_args += ", raster_order_group(0)";
11703 ep_args += "]]";
11704 }
11705 is_using_builtin_array = false;
11706 }
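// As a concrete (hypothetical) example, a size-3 array of SSBOs named "ssbo"
// assigned index 5 expands to:
//   device SSBO* ssbo_0 [[buffer(5)]], device SSBO* ssbo_1 [[buffer(6)]], device SSBO* ssbo_2 [[buffer(7)]]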
11707 else
11708 {
11709 if (!ep_args.empty())
11710 ep_args += ", ";
11711 ep_args +=
11712 get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(var_id) + r.name;
11713 ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
11714 if (interlocked_resources.count(var_id))
11715 ep_args += ", raster_order_group(0)";
11716 ep_args += "]]";
11717 }
11718 break;
11719 }
11720 case SPIRType::Sampler:
11721 if (!ep_args.empty())
11722 ep_args += ", ";
11723 ep_args += sampler_type(type, var_id) + " " + r.name;
11724 ep_args += " [[sampler(" + convert_to_string(r.index) + ")]]";
11725 break;
11726 case SPIRType::Image:
11727 {
11728 if (!ep_args.empty())
11729 ep_args += ", ";
11730
11731 // Use Metal's native frame-buffer fetch API for subpass inputs.
11732 const auto &basetype = get<SPIRType>(var.basetype);
11733 if (!type_is_msl_framebuffer_fetch(basetype))
11734 {
11735 ep_args += image_type_glsl(type, var_id) + " " + r.name;
11736 if (r.plane > 0)
11737 ep_args += join(plane_name_suffix, r.plane);
11738 ep_args += " [[texture(" + convert_to_string(r.index) + ")";
11739 if (interlocked_resources.count(var_id))
11740 ep_args += ", raster_order_group(0)";
11741 ep_args += "]]";
11742 }
11743 else
11744 {
11745 if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 3))
11746 SPIRV_CROSS_THROW("Framebuffer fetch on Mac is not supported before MSL 2.3.");
11747 ep_args += image_type_glsl(type, var_id) + " " + r.name;
11748 ep_args += " [[color(" + convert_to_string(r.index) + ")]]";
11749 }
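// That is, instead of a [[texture(n)]] binding, a subpass input becomes
// something like "float4 inputAttachment [[color(0)]]", with the color index
// taken from the resource index computed above (name and type illustrative).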
11750
11751 // Emulate texture2D atomic operations
11752 if (atomic_image_vars.count(var.self))
11753 {
11754 ep_args += ", device atomic_" + type_to_glsl(get<SPIRType>(basetype.image.type), 0);
11755 ep_args += "* " + r.name + "_atomic";
11756 ep_args += " [[buffer(" + convert_to_string(r.secondary_index) + ")";
11757 if (interlocked_resources.count(var_id))
11758 ep_args += ", raster_order_group(0)";
11759 ep_args += "]]";
11760 }
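// For instance (illustrative), a storage image "img" used with atomics gains a
// companion buffer: "..., device atomic_uint* img_atomic [[buffer(k)]]", where
// k is the secondary resource index reserved for the emulation buffer.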
11761 break;
11762 }
11763 default:
11764 if (!ep_args.empty())
11765 ep_args += ", ";
11766 if (!type.pointer)
11767 ep_args += get_type_address_space(get<SPIRType>(var.basetype), var_id) + " " +
11768 type_to_glsl(type, var_id) + "& " + r.name;
11769 else
11770 ep_args += type_to_glsl(type, var_id) + " " + r.name;
11771 ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
11772 if (interlocked_resources.count(var_id))
11773 ep_args += ", raster_order_group(0)";
11774 ep_args += "]]";
11775 break;
11776 }
11777 }
11778 }
11779
11780 // Returns a string containing a comma-delimited list of args for the entry point function
11781 // This is the "classic" method of MSL 1 when we don't have argument buffer support.
11782 string CompilerMSL::entry_point_args_classic(bool append_comma)
11783 {
11784 string ep_args = entry_point_arg_stage_in();
11785 entry_point_args_discrete_descriptors(ep_args);
11786 entry_point_args_builtin(ep_args);
11787
11788 if (!ep_args.empty() && append_comma)
11789 ep_args += ", ";
11790
11791 return ep_args;
11792 }
11793
11794 void CompilerMSL::fix_up_shader_inputs_outputs()
11795 {
11796 auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
11797
11798 // Emit a guard to ensure we don't execute beyond the last vertex.
11799 // Vertex shaders shouldn't have the problems with barriers in non-uniform control flow that
11800 // tessellation control shaders do, so early returns should be OK. We may need to revisit this
11801 // if it ever becomes possible to use barriers from a vertex shader.
11802 if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
11803 {
11804 entry_func.fixup_hooks_in.push_back([this]() {
11805 statement("if (any(", to_expression(builtin_invocation_id_id),
11806 " >= ", to_expression(builtin_stage_input_size_id), "))");
11807 statement(" return;");
11808 });
11809 }
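// The emitted guard looks roughly like this (builtin names illustrative):
//   if (any(gl_GlobalInvocationID >= spvStageInputSize))
//       return;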
11810
11811 // Look for sampled images and buffers. Add hooks to set up the swizzle constants or array lengths.
11812 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
11813 auto &type = get_variable_data_type(var);
11814 uint32_t var_id = var.self;
11815 bool ssbo = has_decoration(type.self, DecorationBufferBlock);
11816
11817 if (var.storage == StorageClassUniformConstant && !is_hidden_variable(var))
11818 {
11819 if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
11820 {
11821 entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
11822 bool is_array_type = !type.array.empty();
11823
11824 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
11825 if (descriptor_set_is_argument_buffer(desc_set))
11826 {
11827 statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
11828 is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
11829 ".spvSwizzleConstants", "[",
11830 convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
11831 }
11832 else
11833 {
11834 // If we have an array of images, we need to be able to index into it, so take a pointer instead.
11835 statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
11836 is_array_type ? " = &" : " = ", to_name(swizzle_buffer_id), "[",
11837 convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
11838 }
11839 });
11840 }
11841 }
11842 else if ((var.storage == StorageClassStorageBuffer || (var.storage == StorageClassUniform && ssbo)) &&
11843 !is_hidden_variable(var))
11844 {
11845 if (buffers_requiring_array_length.count(var.self))
11846 {
11847 entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
11848 bool is_array_type = !type.array.empty();
11849
11850 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
11851 if (descriptor_set_is_argument_buffer(desc_set))
11852 {
11853 statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
11854 is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
11855 ".spvBufferSizeConstants", "[",
11856 convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
11857 }
11858 else
11859 {
11860 // If we have an array of images, we need to be able to index into it, so take a pointer instead.
11861 statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
11862 is_array_type ? " = &" : " = ", to_name(buffer_size_buffer_id), "[",
11863 convert_to_string(get_metal_resource_index(var, type.basetype)), "];");
11864 }
11865 });
11866 }
11867 }
11868 });
11869
11870 // Builtin variables
11871 ir.for_each_typed_id<SPIRVariable>([this, &entry_func](uint32_t, SPIRVariable &var) {
11872 uint32_t var_id = var.self;
11873 BuiltIn bi_type = ir.meta[var_id].decoration.builtin_type;
11874
11875 if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
11876 return;
11877 if (!interface_variable_exists_in_entry_point(var.self))
11878 return;
11879
11880 if (var.storage == StorageClassInput && is_builtin_variable(var) && active_input_builtins.get(bi_type))
11881 {
11882 switch (bi_type)
11883 {
11884 case BuiltInSamplePosition:
11885 entry_func.fixup_hooks_in.push_back([=]() {
11886 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = get_sample_position(",
11887 to_expression(builtin_sample_id_id), ");");
11888 });
11889 break;
11890 case BuiltInFragCoord:
11891 if (is_sample_rate())
11892 {
11893 entry_func.fixup_hooks_in.push_back([=]() {
11894 statement(to_expression(var_id), ".xy += get_sample_position(",
11895 to_expression(builtin_sample_id_id), ") - 0.5;");
11896 });
11897 }
11898 break;
11899 case BuiltInHelperInvocation:
11900 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
11901 SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.3 on iOS.");
11902 else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
11903 SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
11904
11905 entry_func.fixup_hooks_in.push_back([=]() {
11906 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = simd_is_helper_thread();");
11907 });
11908 break;
11909 case BuiltInInvocationId:
11910 // This is direct-mapped without multi-patch workgroups.
11911 if (get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup)
11912 break;
11913
11914 entry_func.fixup_hooks_in.push_back([=]() {
11915 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11916 to_expression(builtin_invocation_id_id), ".x % ", this->get_entry_point().output_vertices,
11917 ";");
11918 });
11919 break;
11920 case BuiltInPrimitiveId:
11921 // This is natively supported by fragment and tessellation evaluation shaders.
11922 // In tessellation control shaders, this is direct-mapped without multi-patch workgroups.
11923 if (get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup)
11924 break;
11925
11926 entry_func.fixup_hooks_in.push_back([=]() {
11927 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = min(",
11928 to_expression(builtin_invocation_id_id), ".x / ", this->get_entry_point().output_vertices,
11929 ", spvIndirectParams[1] - 1);");
11930 });
11931 break;
11932 case BuiltInPatchVertices:
11933 if (get_execution_model() == ExecutionModelTessellationEvaluation)
11934 entry_func.fixup_hooks_in.push_back([=]() {
11935 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11936 to_expression(patch_stage_in_var_id), ".gl_in.size();");
11937 });
11938 else
11939 entry_func.fixup_hooks_in.push_back([=]() {
11940 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = spvIndirectParams[0];");
11941 });
11942 break;
11943 case BuiltInTessCoord:
11944 // Emit a fixup to account for the shifted domain. Don't do this for triangles;
11945 // MoltenVK will just reverse the winding order instead.
11946 if (msl_options.tess_domain_origin_lower_left && !get_entry_point().flags.get(ExecutionModeTriangles))
11947 {
11948 string tc = to_expression(var_id);
11949 entry_func.fixup_hooks_in.push_back([=]() { statement(tc, ".y = 1.0 - ", tc, ".y;"); });
11950 }
11951 break;
11952 case BuiltInSubgroupId:
11953 if (!msl_options.emulate_subgroups)
11954 break;
11955 // For subgroup emulation, this is the same as the local invocation index.
11956 entry_func.fixup_hooks_in.push_back([=]() {
11957 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11958 to_expression(builtin_local_invocation_index_id), ";");
11959 });
11960 break;
11961 case BuiltInNumSubgroups:
11962 if (!msl_options.emulate_subgroups)
11963 break;
11964 // For subgroup emulation, this is the same as the workgroup size.
11965 entry_func.fixup_hooks_in.push_back([=]() {
11966 auto &type = expression_type(builtin_workgroup_size_id);
11967 string size_expr = to_expression(builtin_workgroup_size_id);
11968 if (type.vecsize >= 3)
11969 size_expr = join(size_expr, ".x * ", size_expr, ".y * ", size_expr, ".z");
11970 else if (type.vecsize == 2)
11971 size_expr = join(size_expr, ".x * ", size_expr, ".y");
11972 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", size_expr, ";");
11973 });
11974 break;
11975 case BuiltInSubgroupLocalInvocationId:
11976 if (!msl_options.emulate_subgroups)
11977 break;
11978 // For subgroup emulation, assume subgroups of size 1.
11979 entry_func.fixup_hooks_in.push_back(
11980 [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;"); });
11981 break;
11982 case BuiltInSubgroupSize:
11983 if (msl_options.emulate_subgroups)
11984 {
11985 // For subgroup emulation, assume subgroups of size 1.
11986 entry_func.fixup_hooks_in.push_back(
11987 [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = 1;"); });
11988 }
11989 else if (msl_options.fixed_subgroup_size != 0)
11990 {
11991 entry_func.fixup_hooks_in.push_back([=]() {
11992 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
11993 msl_options.fixed_subgroup_size, ";");
11994 });
11995 }
11996 break;
11997 case BuiltInSubgroupEqMask:
11998 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
11999 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
12000 if (!msl_options.supports_msl_version(2, 1))
12001 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
12002 entry_func.fixup_hooks_in.push_back([=]() {
12003 if (msl_options.is_ios())
12004 {
12005 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", "uint4(1 << ",
12006 to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
12007 }
12008 else
12009 {
12010 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12011 to_expression(builtin_subgroup_invocation_id_id), " >= 32 ? uint4(0, (1 << (",
12012 to_expression(builtin_subgroup_invocation_id_id), " - 32)), uint2(0)) : uint4(1 << ",
12013 to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
12014 }
12015 });
12016 break;
12017 case BuiltInSubgroupGeMask:
12018 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
12019 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
12020 if (!msl_options.supports_msl_version(2, 1))
12021 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
12022 if (msl_options.fixed_subgroup_size != 0)
12023 add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
12024 entry_func.fixup_hooks_in.push_back([=]() {
12025 // Case where index < 32, size < 32:
12026 // mask0 = bfi(0, 0xFFFFFFFF, index, size - index);
12027 // mask1 = bfi(0, 0xFFFFFFFF, 0, 0); // Gives 0
12028 // Case where index < 32 but size >= 32:
12029 // mask0 = bfi(0, 0xFFFFFFFF, index, 32 - index);
12030 // mask1 = bfi(0, 0xFFFFFFFF, 0, size - 32);
12031 // Case where index >= 32:
12032 // mask0 = bfi(0, 0xFFFFFFFF, 32, 0); // Gives 0
12033 // mask1 = bfi(0, 0xFFFFFFFF, index - 32, size - index);
12034 // This is expressed without branches to avoid divergent
12035 // control flow--hence the complicated min/max expressions.
12036 // This is further complicated by the fact that if you attempt
12037 // to bfi/bfe out-of-bounds on Metal, undefined behavior is the
12038 // result.
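// Worked example (hypothetical values): with fixed_subgroup_size = 64 and
// invocation index 40, mask0 = insert_bits(0u, 0xFFFFFFFF, 32, 0) == 0 and
// mask1 = insert_bits(0u, 0xFFFFFFFF, 8, 24) == 0xFFFFFF00, i.e. bits 40..63.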
12039 if (msl_options.fixed_subgroup_size > 32)
12040 {
12041 // Don't use the subgroup size variable with fixed subgroup sizes,
12042 // since the variables could be defined in the wrong order.
12043 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12044 " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
12045 to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(32 - (int)",
12046 to_expression(builtin_subgroup_invocation_id_id),
12047 ", 0)), insert_bits(0u, 0xFFFFFFFF,"
12048 " (uint)max((int)",
12049 to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), ",
12050 msl_options.fixed_subgroup_size, " - max(",
12051 to_expression(builtin_subgroup_invocation_id_id),
12052 ", 32u)), uint2(0));");
12053 }
12054 else if (msl_options.fixed_subgroup_size != 0)
12055 {
12056 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12057 " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
12058 to_expression(builtin_subgroup_invocation_id_id), ", ",
12059 msl_options.fixed_subgroup_size, " - ",
12060 to_expression(builtin_subgroup_invocation_id_id),
12061 "), uint3(0));");
12062 }
12063 else if (msl_options.is_ios())
12064 {
12065 // On iOS, the SIMD-group size will currently never exceed 32.
12066 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12067 " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
12068 to_expression(builtin_subgroup_invocation_id_id), ", ",
12069 to_expression(builtin_subgroup_size_id), " - ",
12070 to_expression(builtin_subgroup_invocation_id_id), "), uint3(0));");
12071 }
12072 else
12073 {
12074 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12075 " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
12076 to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(min((int)",
12077 to_expression(builtin_subgroup_size_id), ", 32) - (int)",
12078 to_expression(builtin_subgroup_invocation_id_id),
12079 ", 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
12080 to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), (uint)max((int)",
12081 to_expression(builtin_subgroup_size_id), " - (int)max(",
12082 to_expression(builtin_subgroup_invocation_id_id), ", 32u), 0)), uint2(0));");
12083 }
12084 });
12085 break;
12086 case BuiltInSubgroupGtMask:
12087 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
12088 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
12089 if (!msl_options.supports_msl_version(2, 1))
12090 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
12091 add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
12092 entry_func.fixup_hooks_in.push_back([=]() {
12093 // The same logic applies here, except now the index is one
12094 // more than the subgroup invocation ID.
12095 if (msl_options.fixed_subgroup_size > 32)
12096 {
12097 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12098 " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
12099 to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(32 - (int)",
12100 to_expression(builtin_subgroup_invocation_id_id),
12101 " - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
12102 to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), ",
12103 msl_options.fixed_subgroup_size, " - max(",
12104 to_expression(builtin_subgroup_invocation_id_id),
12105 " + 1, 32u)), uint2(0));");
12106 }
12107 else if (msl_options.fixed_subgroup_size != 0)
12108 {
12109 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12110 " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
12111 to_expression(builtin_subgroup_invocation_id_id), " + 1, ",
12112 msl_options.fixed_subgroup_size, " - ",
12113 to_expression(builtin_subgroup_invocation_id_id),
12114 " - 1), uint3(0));");
12115 }
12116 else if (msl_options.is_ios())
12117 {
12118 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12119 " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
12120 to_expression(builtin_subgroup_invocation_id_id), " + 1, ",
12121 to_expression(builtin_subgroup_size_id), " - ",
12122 to_expression(builtin_subgroup_invocation_id_id), " - 1), uint3(0));");
12123 }
12124 else
12125 {
12126 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12127 " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
12128 to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(min((int)",
12129 to_expression(builtin_subgroup_size_id), ", 32) - (int)",
12130 to_expression(builtin_subgroup_invocation_id_id),
12131 " - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
12132 to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), (uint)max((int)",
12133 to_expression(builtin_subgroup_size_id), " - (int)max(",
12134 to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), 0)), uint2(0));");
12135 }
12136 });
12137 break;
12138 case BuiltInSubgroupLeMask:
12139 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
12140 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
12141 if (!msl_options.supports_msl_version(2, 1))
12142 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
12143 add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
12144 entry_func.fixup_hooks_in.push_back([=]() {
12145 if (msl_options.is_ios())
12146 {
12147 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12148 " = uint4(extract_bits(0xFFFFFFFF, 0, ",
12149 to_expression(builtin_subgroup_invocation_id_id), " + 1), uint3(0));");
12150 }
12151 else
12152 {
12153 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12154 " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
12155 to_expression(builtin_subgroup_invocation_id_id),
12156 " + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
12157 to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0)), uint2(0));");
12158 }
12159 });
12160 break;
12161 case BuiltInSubgroupLtMask:
12162 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
12163 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
12164 if (!msl_options.supports_msl_version(2, 1))
12165 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
12166 add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
12167 entry_func.fixup_hooks_in.push_back([=]() {
12168 if (msl_options.is_ios())
12169 {
12170 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12171 " = uint4(extract_bits(0xFFFFFFFF, 0, ",
12172 to_expression(builtin_subgroup_invocation_id_id), "), uint3(0));");
12173 }
12174 else
12175 {
12176 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12177 " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
12178 to_expression(builtin_subgroup_invocation_id_id),
12179 ", 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
12180 to_expression(builtin_subgroup_invocation_id_id), " - 32, 0)), uint2(0));");
12181 }
12182 });
12183 break;
12184 case BuiltInViewIndex:
12185 if (!msl_options.multiview)
12186 {
12187 // According to the Vulkan spec, when not running under a multiview
12188 // render pass, ViewIndex is 0.
12189 entry_func.fixup_hooks_in.push_back([=]() {
12190 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;");
12191 });
12192 }
12193 else if (msl_options.view_index_from_device_index)
12194 {
12195 // In this case, we take the view index from that of the device we're running on.
12196 entry_func.fixup_hooks_in.push_back([=]() {
12197 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12198 msl_options.device_index, ";");
12199 });
12200 // We actually don't want to set the render_target_array_index here.
12201 // Since every physical device is rendering a different view,
12202 // there's no need for layered rendering here.
12203 }
12204 else if (!msl_options.multiview_layered_rendering)
12205 {
12206 // In this case, the views are rendered one at a time. The view index, then,
12207 // is just the first part of the "view mask".
12208 entry_func.fixup_hooks_in.push_back([=]() {
12209 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12210 to_expression(view_mask_buffer_id), "[0];");
12211 });
12212 }
12213 else if (get_execution_model() == ExecutionModelFragment)
12214 {
12215 // Because we adjusted the view index in the vertex shader, we have to
12216 // adjust it back here.
12217 entry_func.fixup_hooks_in.push_back([=]() {
12218 statement(to_expression(var_id), " += ", to_expression(view_mask_buffer_id), "[0];");
12219 });
12220 }
12221 else if (get_execution_model() == ExecutionModelVertex)
12222 {
12223 // Metal provides no special support for multiview, so we smuggle
12224 // the view index in the instance index.
12225 entry_func.fixup_hooks_in.push_back([=]() {
12226 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12227 to_expression(view_mask_buffer_id), "[0] + (", to_expression(builtin_instance_idx_id),
12228 " - ", to_expression(builtin_base_instance_id), ") % ",
12229 to_expression(view_mask_buffer_id), "[1];");
12230 statement(to_expression(builtin_instance_idx_id), " = (",
12231 to_expression(builtin_instance_idx_id), " - ",
12232 to_expression(builtin_base_instance_id), ") / ", to_expression(view_mask_buffer_id),
12233 "[1] + ", to_expression(builtin_base_instance_id), ";");
12234 });
12235 // In addition to setting the variable itself, we also need to
12236 // set the render_target_array_index with it on output. We have to
12237 // offset this by the base view index, because Metal isn't in on
12238 // our little game here.
12239 entry_func.fixup_hooks_out.push_back([=]() {
12240 statement(to_expression(builtin_layer_id), " = ", to_expression(var_id), " - ",
12241 to_expression(view_mask_buffer_id), "[0];");
12242 });
12243 }
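// Net effect, with illustrative names: the vertex shader sees
//   gl_ViewIndex = spvViewMask[0] + (gl_InstanceIndex - gl_BaseInstance) % spvViewMask[1];
//   gl_InstanceIndex = (gl_InstanceIndex - gl_BaseInstance) / spvViewMask[1] + gl_BaseInstance;
// and on output, gl_Layer = gl_ViewIndex - spvViewMask[0].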
12244 break;
12245 case BuiltInDeviceIndex:
12246 // Metal pipelines belong to the devices which create them, so we'll
12247 // need to create a MTLPipelineState for every MTLDevice in a grouped
12248 // VkDevice. We can assume, then, that the device index is constant.
12249 entry_func.fixup_hooks_in.push_back([=]() {
12250 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12251 msl_options.device_index, ";");
12252 });
12253 break;
12254 case BuiltInWorkgroupId:
12255 if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInWorkgroupId))
12256 break;
12257
12258 // The vkCmdDispatchBase() command lets the client set the base value
12259 // of WorkgroupId. Metal has no direct equivalent; we must make this
12260 // adjustment ourselves.
12261 entry_func.fixup_hooks_in.push_back([=]() {
12262 statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id), ";");
12263 });
12264 break;
12265 case BuiltInGlobalInvocationId:
12266 if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInGlobalInvocationId))
12267 break;
12268
12269 // GlobalInvocationId is defined as LocalInvocationId + WorkgroupId * WorkgroupSize.
12270 // This needs to be adjusted too.
12271 entry_func.fixup_hooks_in.push_back([=]() {
12272 auto &execution = this->get_entry_point();
12273 uint32_t workgroup_size_id = execution.workgroup_size.constant;
12274 if (workgroup_size_id)
12275 statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
12276 " * ", to_expression(workgroup_size_id), ";");
12277 else
12278 statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
12279 " * uint3(", execution.workgroup_size.x, ", ", execution.workgroup_size.y, ", ",
12280 execution.workgroup_size.z, ");");
12281 });
12282 break;
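// E.g. for a literal 8x8x1 workgroup (sizes hypothetical), this emits roughly:
//   gl_GlobalInvocationID += spvDispatchBase * uint3(8, 8, 1);
// where spvDispatchBase stands in for the dereferenced dispatch-base argument.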
12283 case BuiltInVertexId:
12284 case BuiltInVertexIndex:
12285 // This is direct-mapped normally.
12286 if (!msl_options.vertex_for_tessellation)
12287 break;
12288
12289 entry_func.fixup_hooks_in.push_back([=]() {
12290 builtin_declaration = true;
12291 switch (msl_options.vertex_index_type)
12292 {
12293 case Options::IndexType::None:
12294 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12295 to_expression(builtin_invocation_id_id), ".x + ",
12296 to_expression(builtin_dispatch_base_id), ".x;");
12297 break;
12298 case Options::IndexType::UInt16:
12299 case Options::IndexType::UInt32:
12300 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", index_buffer_var_name,
12301 "[", to_expression(builtin_invocation_id_id), ".x] + ",
12302 to_expression(builtin_dispatch_base_id), ".x;");
12303 break;
12304 }
12305 builtin_declaration = false;
12306 });
12307 break;
12308 case BuiltInBaseVertex:
12309 // This is direct-mapped normally.
12310 if (!msl_options.vertex_for_tessellation)
12311 break;
12312
12313 entry_func.fixup_hooks_in.push_back([=]() {
12314 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12315 to_expression(builtin_dispatch_base_id), ".x;");
12316 });
12317 break;
12318 case BuiltInInstanceId:
12319 case BuiltInInstanceIndex:
12320 // This is direct-mapped normally.
12321 if (!msl_options.vertex_for_tessellation)
12322 break;
12323
12324 entry_func.fixup_hooks_in.push_back([=]() {
12325 builtin_declaration = true;
12326 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12327 to_expression(builtin_invocation_id_id), ".y + ", to_expression(builtin_dispatch_base_id),
12328 ".y;");
12329 builtin_declaration = false;
12330 });
12331 break;
12332 case BuiltInBaseInstance:
12333 // This is direct-mapped normally.
12334 if (!msl_options.vertex_for_tessellation)
12335 break;
12336
12337 entry_func.fixup_hooks_in.push_back([=]() {
12338 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12339 to_expression(builtin_dispatch_base_id), ".y;");
12340 });
12341 break;
12342 default:
12343 break;
12344 }
12345 }
12346 else if (var.storage == StorageClassOutput && get_execution_model() == ExecutionModelFragment &&
12347 is_builtin_variable(var) && active_output_builtins.get(bi_type) &&
12348 bi_type == BuiltInSampleMask && has_additional_fixed_sample_mask())
12349 {
12350 // If the additional fixed sample mask was set, we need to adjust the sample_mask
12351 // output to reflect that. If the shader outputs the sample_mask itself too, we need
12352 // to AND the two masks to get the final one.
12353 string op_str = does_shader_write_sample_mask ? " &= " : " = ";
12354 entry_func.fixup_hooks_out.push_back([=]() {
12355 statement(to_expression(builtin_sample_mask_id), op_str, additional_fixed_sample_mask_str(), ";");
12356 });
12357 }
12358 });
12359 }
12360
12361 // Returns the Metal index of the resource of the specified type as used by the specified variable.
12362 uint32_t CompilerMSL::get_metal_resource_index(SPIRVariable &var, SPIRType::BaseType basetype, uint32_t plane)
12363 {
12364 auto &execution = get_entry_point();
12365 auto &var_dec = ir.meta[var.self].decoration;
12366 auto &var_type = get<SPIRType>(var.basetype);
12367 uint32_t var_desc_set = (var.storage == StorageClassPushConstant) ? kPushConstDescSet : var_dec.set;
12368 uint32_t var_binding = (var.storage == StorageClassPushConstant) ? kPushConstBinding : var_dec.binding;
12369
12370 // If a matching binding has been specified, find and use it.
12371 auto itr = resource_bindings.find({ execution.model, var_desc_set, var_binding });
12372
12373 // Atomic helper buffers for image atomics need to use secondary bindings as well.
12374 bool use_secondary_binding = (var_type.basetype == SPIRType::SampledImage && basetype == SPIRType::Sampler) ||
12375 basetype == SPIRType::AtomicCounter;
12376
12377 auto resource_decoration =
12378 use_secondary_binding ? SPIRVCrossDecorationResourceIndexSecondary : SPIRVCrossDecorationResourceIndexPrimary;
12379
12380 if (plane == 1)
12381 resource_decoration = SPIRVCrossDecorationResourceIndexTertiary;
12382 if (plane == 2)
12383 resource_decoration = SPIRVCrossDecorationResourceIndexQuaternary;
12384
12385 if (itr != end(resource_bindings))
12386 {
12387 auto &remap = itr->second;
12388 remap.second = true;
12389 switch (basetype)
12390 {
12391 case SPIRType::Image:
12392 set_extended_decoration(var.self, resource_decoration, remap.first.msl_texture + plane);
12393 return remap.first.msl_texture + plane;
12394 case SPIRType::Sampler:
12395 set_extended_decoration(var.self, resource_decoration, remap.first.msl_sampler);
12396 return remap.first.msl_sampler;
12397 default:
12398 set_extended_decoration(var.self, resource_decoration, remap.first.msl_buffer);
12399 return remap.first.msl_buffer;
12400 }
12401 }
12402
12403 // If we have already allocated an index, keep using it.
12404 if (has_extended_decoration(var.self, resource_decoration))
12405 return get_extended_decoration(var.self, resource_decoration);
12406
12407 auto &type = get<SPIRType>(var.basetype);
12408
12409 if (type_is_msl_framebuffer_fetch(type))
12410 {
12411 // Frame-buffer fetch gets its fallback resource index from the input attachment index,
12412 // which is then treated as color index.
12413 return get_decoration(var.self, DecorationInputAttachmentIndex);
12414 }
12415 else if (msl_options.enable_decoration_binding)
12416 {
12417 // Allow user to enable decoration binding.
12418 // If there is no explicit mapping of bindings to MSL, use the declared binding as a fallback.
12419 if (has_decoration(var.self, DecorationBinding))
12420 {
12421 var_binding = get_decoration(var.self, DecorationBinding);
12422 // Avoid emitting sentinel bindings.
12423 if (var_binding < 0x80000000u)
12424 return var_binding;
12425 }
12426 }
12427
12428 // If we did not explicitly remap, allocate bindings on demand.
12429 // We cannot reliably use Binding decorations since SPIR-V and MSL's binding models are very different.
12430
12431 bool allocate_argument_buffer_ids = false;
12432
12433 if (var.storage != StorageClassPushConstant)
12434 allocate_argument_buffer_ids = descriptor_set_is_argument_buffer(var_desc_set);
12435
12436 uint32_t binding_stride = 1;
12437 for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
12438 binding_stride *= to_array_size_literal(type, i);
12439
12440 assert(binding_stride != 0);
12441
12442 // If a binding has not been specified, revert to incrementing resource indices.
12443 uint32_t resource_index;
12444
12445 if (allocate_argument_buffer_ids)
12446 {
12447 // Allocate from a flat ID binding space.
12448 resource_index = next_metal_resource_ids[var_desc_set];
12449 next_metal_resource_ids[var_desc_set] += binding_stride;
12450 }
12451 else
12452 {
12453 // Allocate from plain bindings which are allocated per resource type.
12454 switch (basetype)
12455 {
12456 case SPIRType::Image:
12457 resource_index = next_metal_resource_index_texture;
12458 next_metal_resource_index_texture += binding_stride;
12459 break;
12460 case SPIRType::Sampler:
12461 resource_index = next_metal_resource_index_sampler;
12462 next_metal_resource_index_sampler += binding_stride;
12463 break;
12464 default:
12465 resource_index = next_metal_resource_index_buffer;
12466 next_metal_resource_index_buffer += binding_stride;
12467 break;
12468 }
12469 }
12470
12471 set_extended_decoration(var.self, resource_decoration, resource_index);
12472 return resource_index;
12473 }
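// Minimal usage sketch: pre-seeding an explicit remap (fields as consumed by the
// lookup above) makes resource_bindings.find() hit, so the on-demand allocation
// never runs for that resource. "compiler" and the values are illustrative:
//   MSLResourceBinding b;
//   b.stage = spv::ExecutionModelFragment;
//   b.basetype = SPIRType::Float;   // i.e. an ordinary buffer resource
//   b.desc_set = 0;
//   b.binding = 1;
//   b.msl_buffer = 4;               // becomes [[buffer(4)]]
//   compiler.add_msl_resource_binding(b);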
12474
12475 bool CompilerMSL::type_is_msl_framebuffer_fetch(const SPIRType &type) const
12476 {
12477 return type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
12478 msl_options.use_framebuffer_fetch_subpasses;
12479 }
12480
12481 bool CompilerMSL::type_is_pointer(const SPIRType &type) const
12482 {
12483 if (!type.pointer)
12484 return false;
12485 auto &parent_type = get<SPIRType>(type.parent_type);
12486 // Safeguards when we forget to set pointer_depth (there is an assert for it in type_to_glsl),
12487 // but the extra check shouldn't hurt.
12488 return (type.pointer_depth > parent_type.pointer_depth) || !parent_type.pointer;
12489 }
12490
12491 bool CompilerMSL::type_is_pointer_to_pointer(const SPIRType &type) const
12492 {
12493 if (!type.pointer)
12494 return false;
12495 auto &parent_type = get<SPIRType>(type.parent_type);
12496 return type.pointer_depth > parent_type.pointer_depth && type_is_pointer(parent_type);
12497 }
12498
12499 string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
12500 {
12501 auto &var = get<SPIRVariable>(arg.id);
12502 auto &type = get_variable_data_type(var);
12503 auto &var_type = get<SPIRType>(arg.type);
12504 StorageClass type_storage = var_type.storage;
12505 bool is_pointer = var_type.pointer;
12506
12507 // If we need to modify the name of the variable, make sure we use the original variable.
12508 // Our alias is just a shadow variable.
12509 uint32_t name_id = var.self;
12510 if (arg.alias_global_variable && var.basevariable)
12511 name_id = var.basevariable;
12512
12513 bool constref = !arg.alias_global_variable && is_pointer && arg.write_count == 0;
12514 // Framebuffer fetch is a plain value; const looks out of place, but it is not wrong.
12515 if (type_is_msl_framebuffer_fetch(type))
12516 constref = false;
12517
12518 bool type_is_image = type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage ||
12519 type.basetype == SPIRType::Sampler;
12520
12521 // Arrays of images/samplers in MSL are always const.
12522 if (!type.array.empty() && type_is_image)
12523 constref = true;
12524
12525 const char *cv_qualifier = constref ? "const " : "";
12526 string decl;
12527
12528 // If this is a combined image-sampler for a 2D image with floating-point type,
12529 // we emitted the 'spvDynamicImageSampler' type, and this is *not* an alias parameter
12530 // for a global, then we need to emit a "dynamic" combined image-sampler.
12531 // Unfortunately, this is necessary to properly support passing around
12532 // combined image-samplers with Y'CbCr conversions on them.
12533 bool is_dynamic_img_sampler = !arg.alias_global_variable && type.basetype == SPIRType::SampledImage &&
12534 type.image.dim == Dim2D && type_is_floating_point(get<SPIRType>(type.image.type)) &&
12535 spv_function_implementations.count(SPVFuncImplDynamicImageSampler);
12536
12537 // Allow Metal to use the array<T> template to make arrays a value type
12538 string address_space = get_argument_address_space(var);
12539 bool builtin = has_decoration(var.self, DecorationBuiltIn);
12540 auto builtin_type = BuiltIn(get_decoration(arg.id, DecorationBuiltIn));
12541
12542 if (address_space == "threadgroup")
12543 is_using_builtin_array = true;
12544
12545 if (var.basevariable && (var.basevariable == stage_in_ptr_var_id || var.basevariable == stage_out_ptr_var_id))
12546 decl = join(cv_qualifier, type_to_glsl(type, arg.id));
12547 else if (builtin)
12548 {
12549 // Only use templated array for Clip/Cull distance when feasible.
12550 // In other scenarios, we need to override array length for tess levels (if used as outputs),
12551 // or we need to emit the expected type for builtins (uint vs int).
12552 auto storage = get<SPIRType>(var.basetype).storage;
12553
12554 if (storage == StorageClassInput &&
12555 (builtin_type == BuiltInTessLevelInner || builtin_type == BuiltInTessLevelOuter))
12556 {
12557 is_using_builtin_array = false;
12558 }
12559 else if (builtin_type != BuiltInClipDistance && builtin_type != BuiltInCullDistance)
12560 {
12561 is_using_builtin_array = true;
12562 }
12563
12564 if (storage == StorageClassOutput && variable_storage_requires_stage_io(storage) &&
12565 !is_stage_output_builtin_masked(builtin_type))
12566 is_using_builtin_array = true;
12567
12568 if (is_using_builtin_array)
12569 decl = join(cv_qualifier, builtin_type_decl(builtin_type, arg.id));
12570 else
12571 decl = join(cv_qualifier, type_to_glsl(type, arg.id));
12572 }
12573 else if ((type_storage == StorageClassUniform || type_storage == StorageClassStorageBuffer) && is_array(type))
12574 {
12575 is_using_builtin_array = true;
12576 decl += join(cv_qualifier, type_to_glsl(type, arg.id), "*");
12577 }
12578 else if (is_dynamic_img_sampler)
12579 {
12580 decl = join(cv_qualifier, "spvDynamicImageSampler<", type_to_glsl(get<SPIRType>(type.image.type)), ">");
12581 // Mark the variable so that we can handle passing it to another function.
12582 set_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
12583 }
12584 else
12585 {
12586 // The type is a pointer type; we need to emit the cv_qualifier late.
12587 if (type_is_pointer(type))
12588 {
12589 decl = type_to_glsl(type, arg.id);
12590 if (*cv_qualifier != '\0')
12591 decl += join(" ", cv_qualifier);
12592 }
12593 else
12594 decl = join(cv_qualifier, type_to_glsl(type, arg.id));
12595 }
12596
12597 bool opaque_handle = type_storage == StorageClassUniformConstant;
12598
12599 if (!builtin && !opaque_handle && !is_pointer &&
12600 (type_storage == StorageClassFunction || type_storage == StorageClassGeneric))
12601 {
12602 // If the argument is a pure value and not an opaque type, we will pass by value.
12603 if (msl_options.force_native_arrays && is_array(type))
12604 {
12605 // We are receiving an array by value. This is problematic.
12606 // We cannot be sure of the target address space since we are supposed to receive a copy,
12607 // but this is not possible with MSL without some extra work.
12608 // We will have to assume we're getting a reference in thread address space.
12609 // If we happen to get a reference in constant address space, the caller must emit a copy and pass that.
12610 // Thread const therefore becomes the only logical choice, since we cannot "create" a constant array from
12611 // non-constant arrays, but we can create thread const from constant.
12612 decl = string("thread const ") + decl;
12613 decl += " (&";
12614 const char *restrict_kw = to_restrict(name_id);
12615 if (*restrict_kw)
12616 {
12617 decl += " ";
12618 decl += restrict_kw;
12619 }
12620 decl += to_expression(name_id);
12621 decl += ")";
12622 decl += type_to_array_glsl(type);
12623 }
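// E.g. (illustrative) a float[3] parameter received by value is declared as
//   thread const float (&foo)[3]
// so the callee still sees array semantics without an actual copy at this level.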
12624 else
12625 {
12626 if (!address_space.empty())
12627 decl = join(address_space, " ", decl);
12628 decl += " ";
12629 decl += to_expression(name_id);
12630 }
12631 }
12632 else if (is_array(type) && !type_is_image)
12633 {
12634 // Arrays of images and samplers are special cased.
12635 if (!address_space.empty())
12636 decl = join(address_space, " ", decl);
12637
12638 if (msl_options.argument_buffers)
12639 {
12640 uint32_t desc_set = get_decoration(name_id, DecorationDescriptorSet);
12641 if ((type_storage == StorageClassUniform || type_storage == StorageClassStorageBuffer) &&
12642 descriptor_set_is_argument_buffer(desc_set))
12643 {
12644 // An awkward case where we need to emit *more* address space declarations (yay!).
12645 // An example is where we pass down an array of buffer pointers to leaf functions.
12646 // It's a constant array containing pointers to constants.
12647 // The pointer array is always constant however. E.g.
12648 // device SSBO * constant (&array)[N].
12649 // const device SSBO * constant (&array)[N].
12650 // constant SSBO * constant (&array)[N].
12651 // However, this only matters for argument buffers, since for MSL 1.0 style codegen,
12652 // we emit the buffer array on stack instead, and that seems to work just fine apparently.
12653
12654 // If the argument was marked as being in device address space, any pointer to member would
12655 // be const device, not constant.
12656 if (argument_buffer_device_storage_mask & (1u << desc_set))
12657 decl += " const device";
12658 else
12659 decl += " constant";
12660 }
12661 }
12662
12663 // Special case, need to override the array size here if we're using tess level as an argument.
12664 if (get_execution_model() == ExecutionModelTessellationControl && builtin &&
12665 (builtin_type == BuiltInTessLevelInner || builtin_type == BuiltInTessLevelOuter))
12666 {
12667 uint32_t array_size = get_physical_tess_level_array_size(builtin_type);
12668 if (array_size == 1)
12669 {
12670 decl += " &";
12671 decl += to_expression(name_id);
12672 }
12673 else
12674 {
12675 decl += " (&";
12676 decl += to_expression(name_id);
12677 decl += ")";
12678 decl += join("[", array_size, "]");
12679 }
12680 }
12681 else
12682 {
12683 auto array_size_decl = type_to_array_glsl(type);
12684 if (array_size_decl.empty())
12685 decl += "& ";
12686 else
12687 decl += " (&";
12688
12689 const char *restrict_kw = to_restrict(name_id);
12690 if (*restrict_kw)
12691 {
12692 decl += " ";
12693 decl += restrict_kw;
12694 }
12695 decl += to_expression(name_id);
12696
12697 if (!array_size_decl.empty())
12698 {
12699 decl += ")";
12700 decl += array_size_decl;
12701 }
12702 }
12703 }
12704 else if (!opaque_handle && (!pull_model_inputs.count(var.basevariable) || type.basetype == SPIRType::Struct))
12705 {
12706 // If this is going to be a reference to a variable pointer, the address space
12707 // for the reference has to go before the '&', but after the '*'.
12708 if (!address_space.empty())
12709 {
12710 if (type_is_pointer(type))
12711 {
12712 if (*cv_qualifier == '\0')
12713 decl += ' ';
12714 decl += join(address_space, " ");
12715 }
12716 else
12717 decl = join(address_space, " ", decl);
12718 }
12719 decl += "&";
12720 decl += " ";
12721 decl += to_restrict(name_id);
12722 decl += to_expression(name_id);
12723 }
12724 else
12725 {
12726 if (!address_space.empty())
12727 decl = join(address_space, " ", decl);
12728 decl += " ";
12729 decl += to_expression(name_id);
12730 }
12731
12732 // Emulate texture2D atomic operations
12733 auto *backing_var = maybe_get_backing_variable(name_id);
12734 if (backing_var && atomic_image_vars.count(backing_var->self))
12735 {
12736 decl += ", device atomic_" + type_to_glsl(get<SPIRType>(var_type.image.type), 0);
12737 decl += "* " + to_expression(name_id) + "_atomic";
12738 }
12739
12740 is_using_builtin_array = false;
12741
12742 return decl;
12743 }
12744
12745 // If we're currently in the entry point function, and the object
12746 // has a qualified name, use it, otherwise use the standard name.
12747 string CompilerMSL::to_name(uint32_t id, bool allow_alias) const
12748 {
12749 if (current_function && (current_function->self == ir.default_entry_point))
12750 {
12751 auto *m = ir.find_meta(id);
12752 if (m && !m->decoration.qualified_alias.empty())
12753 return m->decoration.qualified_alias;
12754 }
12755 return Compiler::to_name(id, allow_alias);
12756 }
12757
12758 // Returns a name that combines the name of the struct with the name of the member, except for Builtins
12759 string CompilerMSL::to_qualified_member_name(const SPIRType &type, uint32_t index)
12760 {
12761 // Don't qualify Builtin names because they are unique and are treated as such when building expressions
12762 BuiltIn builtin = BuiltInMax;
12763 if (is_member_builtin(type, index, &builtin))
12764 return builtin_to_glsl(builtin, type.storage);
12765
12766 // Strip any underscore prefix from member name
12767 string mbr_name = to_member_name(type, index);
12768 size_t startPos = mbr_name.find_first_not_of("_");
12769 mbr_name = (startPos != string::npos) ? mbr_name.substr(startPos) : "";
12770 return join(to_name(type.self), "_", mbr_name);
12771 }
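// Illustrative result (hypothetical names): for a member "_bar" of a struct
// named "Foo", the underscore prefix is stripped and the result is "Foo_bar";
// a member decorated as a builtin instead resolves to its builtin name
// (e.g. gl_Position, possibly qualified by the stage-out struct).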
12772
12773 // Ensures that the specified name is permanently usable by prepending a prefix
12774 // if the first chars are _ and a digit, which indicate a transient name.
12775 string CompilerMSL::ensure_valid_name(string name, string pfx)
12776 {
12777 return (name.size() >= 2 && name[0] == '_' && isdigit(name[1])) ? (pfx + name) : name;
12778 }
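// Sketch of the behavior: ensure_valid_name("_21", "m") returns "m_21", while
// ensure_valid_name("color", "m") returns "color" unchanged.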
12779
12780 const std::unordered_set<std::string> &CompilerMSL::get_reserved_keyword_set()
12781 {
12782 static const unordered_set<string> keywords = {
12783 "kernel",
12784 "vertex",
12785 "fragment",
12786 "compute",
12787 "bias",
12788 "level",
12789 "gradient2d",
12790 "gradientcube",
12791 "gradient3d",
12792 "min_lod_clamp",
12793 "assert",
12794 "VARIABLE_TRACEPOINT",
12795 "STATIC_DATA_TRACEPOINT",
12796 "STATIC_DATA_TRACEPOINT_V",
12797 "METAL_ALIGN",
12798 "METAL_ASM",
12799 "METAL_CONST",
12800 "METAL_DEPRECATED",
12801 "METAL_ENABLE_IF",
12802 "METAL_FUNC",
12803 "METAL_INTERNAL",
12804 "METAL_NON_NULL_RETURN",
12805 "METAL_NORETURN",
12806 "METAL_NOTHROW",
12807 "METAL_PURE",
12808 "METAL_UNAVAILABLE",
12809 "METAL_IMPLICIT",
12810 "METAL_EXPLICIT",
12811 "METAL_CONST_ARG",
12812 "METAL_ARG_UNIFORM",
12813 "METAL_ZERO_ARG",
12814 "METAL_VALID_LOD_ARG",
12815 "METAL_VALID_LEVEL_ARG",
12816 "METAL_VALID_STORE_ORDER",
12817 "METAL_VALID_LOAD_ORDER",
12818 "METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
12819 "METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
12820 "METAL_VALID_RENDER_TARGET",
12821 "is_function_constant_defined",
12822 "CHAR_BIT",
12823 "SCHAR_MAX",
12824 "SCHAR_MIN",
12825 "UCHAR_MAX",
12826 "CHAR_MAX",
12827 "CHAR_MIN",
12828 "USHRT_MAX",
12829 "SHRT_MAX",
12830 "SHRT_MIN",
12831 "UINT_MAX",
12832 "INT_MAX",
12833 "INT_MIN",
12834 "FLT_DIG",
12835 "FLT_MANT_DIG",
12836 "FLT_MAX_10_EXP",
12837 "FLT_MAX_EXP",
12838 "FLT_MIN_10_EXP",
12839 "FLT_MIN_EXP",
12840 "FLT_RADIX",
12841 "FLT_MAX",
12842 "FLT_MIN",
12843 "FLT_EPSILON",
12844 "FP_ILOGB0",
12845 "FP_ILOGBNAN",
12846 "MAXFLOAT",
12847 "HUGE_VALF",
12848 "INFINITY",
12849 "NAN",
12850 "M_E_F",
12851 "M_LOG2E_F",
12852 "M_LOG10E_F",
12853 "M_LN2_F",
12854 "M_LN10_F",
12855 "M_PI_F",
12856 "M_PI_2_F",
12857 "M_PI_4_F",
12858 "M_1_PI_F",
12859 "M_2_PI_F",
12860 "M_2_SQRTPI_F",
12861 "M_SQRT2_F",
12862 "M_SQRT1_2_F",
12863 "HALF_DIG",
12864 "HALF_MANT_DIG",
12865 "HALF_MAX_10_EXP",
12866 "HALF_MAX_EXP",
12867 "HALF_MIN_10_EXP",
12868 "HALF_MIN_EXP",
12869 "HALF_RADIX",
12870 "HALF_MAX",
12871 "HALF_MIN",
12872 "HALF_EPSILON",
12873 "MAXHALF",
12874 "HUGE_VALH",
12875 "M_E_H",
12876 "M_LOG2E_H",
12877 "M_LOG10E_H",
12878 "M_LN2_H",
12879 "M_LN10_H",
12880 "M_PI_H",
12881 "M_PI_2_H",
12882 "M_PI_4_H",
12883 "M_1_PI_H",
12884 "M_2_PI_H",
12885 "M_2_SQRTPI_H",
12886 "M_SQRT2_H",
12887 "M_SQRT1_2_H",
12888 "DBL_DIG",
12889 "DBL_MANT_DIG",
12890 "DBL_MAX_10_EXP",
12891 "DBL_MAX_EXP",
12892 "DBL_MIN_10_EXP",
12893 "DBL_MIN_EXP",
12894 "DBL_RADIX",
12895 "DBL_MAX",
12896 "DBL_MIN",
12897 "DBL_EPSILON",
12898 "HUGE_VAL",
12899 "M_E",
12900 "M_LOG2E",
12901 "M_LOG10E",
12902 "M_LN2",
12903 "M_LN10",
12904 "M_PI",
12905 "M_PI_2",
12906 "M_PI_4",
12907 "M_1_PI",
12908 "M_2_PI",
12909 "M_2_SQRTPI",
12910 "M_SQRT2",
12911 "M_SQRT1_2",
12912 "quad_broadcast",
12913 };
12914
12915 return keywords;
12916 }
12917
12918 const std::unordered_set<std::string> &CompilerMSL::get_illegal_func_names()
12919 {
12920 static const unordered_set<string> illegal_func_names = {
12921 "main",
12922 "saturate",
12923 "assert",
12924 "fmin3",
12925 "fmax3",
12926 "VARIABLE_TRACEPOINT",
12927 "STATIC_DATA_TRACEPOINT",
12928 "STATIC_DATA_TRACEPOINT_V",
12929 "METAL_ALIGN",
12930 "METAL_ASM",
12931 "METAL_CONST",
12932 "METAL_DEPRECATED",
12933 "METAL_ENABLE_IF",
12934 "METAL_FUNC",
12935 "METAL_INTERNAL",
12936 "METAL_NON_NULL_RETURN",
12937 "METAL_NORETURN",
12938 "METAL_NOTHROW",
12939 "METAL_PURE",
12940 "METAL_UNAVAILABLE",
12941 "METAL_IMPLICIT",
12942 "METAL_EXPLICIT",
12943 "METAL_CONST_ARG",
12944 "METAL_ARG_UNIFORM",
12945 "METAL_ZERO_ARG",
12946 "METAL_VALID_LOD_ARG",
12947 "METAL_VALID_LEVEL_ARG",
12948 "METAL_VALID_STORE_ORDER",
12949 "METAL_VALID_LOAD_ORDER",
12950 "METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
12951 "METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
12952 "METAL_VALID_RENDER_TARGET",
12953 "is_function_constant_defined",
12954 "CHAR_BIT",
12955 "SCHAR_MAX",
12956 "SCHAR_MIN",
12957 "UCHAR_MAX",
12958 "CHAR_MAX",
12959 "CHAR_MIN",
12960 "USHRT_MAX",
12961 "SHRT_MAX",
12962 "SHRT_MIN",
12963 "UINT_MAX",
12964 "INT_MAX",
12965 "INT_MIN",
12966 "FLT_DIG",
12967 "FLT_MANT_DIG",
12968 "FLT_MAX_10_EXP",
12969 "FLT_MAX_EXP",
12970 "FLT_MIN_10_EXP",
12971 "FLT_MIN_EXP",
12972 "FLT_RADIX",
12973 "FLT_MAX",
12974 "FLT_MIN",
12975 "FLT_EPSILON",
12976 "FP_ILOGB0",
12977 "FP_ILOGBNAN",
12978 "MAXFLOAT",
12979 "HUGE_VALF",
12980 "INFINITY",
12981 "NAN",
12982 "M_E_F",
12983 "M_LOG2E_F",
12984 "M_LOG10E_F",
12985 "M_LN2_F",
12986 "M_LN10_F",
12987 "M_PI_F",
12988 "M_PI_2_F",
12989 "M_PI_4_F",
12990 "M_1_PI_F",
12991 "M_2_PI_F",
12992 "M_2_SQRTPI_F",
12993 "M_SQRT2_F",
12994 "M_SQRT1_2_F",
12995 "HALF_DIG",
12996 "HALF_MANT_DIG",
12997 "HALF_MAX_10_EXP",
12998 "HALF_MAX_EXP",
12999 "HALF_MIN_10_EXP",
13000 "HALF_MIN_EXP",
13001 "HALF_RADIX",
13002 "HALF_MAX",
13003 "HALF_MIN",
13004 "HALF_EPSILON",
13005 "MAXHALF",
13006 "HUGE_VALH",
13007 "M_E_H",
13008 "M_LOG2E_H",
13009 "M_LOG10E_H",
13010 "M_LN2_H",
13011 "M_LN10_H",
13012 "M_PI_H",
13013 "M_PI_2_H",
13014 "M_PI_4_H",
13015 "M_1_PI_H",
13016 "M_2_PI_H",
13017 "M_2_SQRTPI_H",
13018 "M_SQRT2_H",
13019 "M_SQRT1_2_H",
13020 "DBL_DIG",
13021 "DBL_MANT_DIG",
13022 "DBL_MAX_10_EXP",
13023 "DBL_MAX_EXP",
13024 "DBL_MIN_10_EXP",
13025 "DBL_MIN_EXP",
13026 "DBL_RADIX",
13027 "DBL_MAX",
13028 "DBL_MIN",
13029 "DBL_EPSILON",
13030 "HUGE_VAL",
13031 "M_E",
13032 "M_LOG2E",
13033 "M_LOG10E",
13034 "M_LN2",
13035 "M_LN10",
13036 "M_PI",
13037 "M_PI_2",
13038 "M_PI_4",
13039 "M_1_PI",
13040 "M_2_PI",
13041 "M_2_SQRTPI",
13042 "M_SQRT2",
13043 "M_SQRT1_2",
13044 };
13045
13046 return illegal_func_names;
13047 }
13048
13049 // Replace all names that match MSL keywords or Metal Standard Library functions.
13050 void CompilerMSL::replace_illegal_names()
13051 {
13052 // FIXME: MSL and GLSL are doing two different things here.
13053 // Agree on convention and remove this override.
13054 auto &keywords = get_reserved_keyword_set();
13055 auto &illegal_func_names = get_illegal_func_names();
13056
13057 ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &) {
13058 auto *meta = ir.find_meta(self);
13059 if (!meta)
13060 return;
13061
13062 auto &dec = meta->decoration;
13063 if (keywords.find(dec.alias) != end(keywords))
13064 dec.alias += "0";
13065 });
13066
13067 ir.for_each_typed_id<SPIRFunction>([&](uint32_t self, SPIRFunction &) {
13068 auto *meta = ir.find_meta(self);
13069 if (!meta)
13070 return;
13071
13072 auto &dec = meta->decoration;
13073 if (illegal_func_names.find(dec.alias) != end(illegal_func_names))
13074 dec.alias += "0";
13075 });
13076
13077 ir.for_each_typed_id<SPIRType>([&](uint32_t self, SPIRType &) {
13078 auto *meta = ir.find_meta(self);
13079 if (!meta)
13080 return;
13081
13082 for (auto &mbr_dec : meta->members)
13083 if (keywords.find(mbr_dec.alias) != end(keywords))
13084 mbr_dec.alias += "0";
13085 });
13086
13087 CompilerGLSL::replace_illegal_names();
13088 }
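// The net effect is that a shader variable that happens to be named, say,
// "level" or "bias" is emitted as "level0"/"bias0", keeping it clear of MSL's
// sampling-option keywords.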
13089
13090 void CompilerMSL::replace_illegal_entry_point_names()
13091 {
13092 auto &illegal_func_names = get_illegal_func_names();
13093
13094 	// It is important to do this before we fix up identifiers,
13095 // since if ep_name is reserved, we will need to fix that up,
13096 // and then copy alias back into entry.name after the fixup.
13097 for (auto &entry : ir.entry_points)
13098 {
13099 // Change both the entry point name and the alias, to keep them synced.
13100 string &ep_name = entry.second.name;
13101 if (illegal_func_names.find(ep_name) != end(illegal_func_names))
13102 ep_name += "0";
13103
13104 ir.meta[entry.first].decoration.alias = ep_name;
13105 }
13106 }
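// For example, a SPIR-V entry point named "main" (an illegal MSL function
// name) is renamed "main0" here, with the alias kept in sync so later name
// lookups agree with entry.name.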
13107
13108 void CompilerMSL::sync_entry_point_aliases_and_names()
13109 {
13110 for (auto &entry : ir.entry_points)
13111 entry.second.name = ir.meta[entry.first].decoration.alias;
13112 }
13113
13114 string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain)
13115 {
13116 if (index < uint32_t(type.member_type_index_redirection.size()))
13117 index = type.member_type_index_redirection[index];
13118
13119 auto *var = maybe_get<SPIRVariable>(base);
13120 // If this is a buffer array, we have to dereference the buffer pointers.
13121 // Otherwise, if this is a pointer expression, dereference it.
13122
13123 bool declared_as_pointer = false;
13124
13125 if (var)
13126 {
13127 // Only allow -> dereference for block types. This is so we get expressions like
13128 // buffer[i]->first_member.second_member, rather than buffer[i]->first->second.
13129 bool is_block = has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);
13130
13131 bool is_buffer_variable =
13132 is_block && (var->storage == StorageClassUniform || var->storage == StorageClassStorageBuffer);
13133 declared_as_pointer = is_buffer_variable && is_array(get<SPIRType>(var->basetype));
13134 }
13135
13136 if (declared_as_pointer || (!ptr_chain && should_dereference(base)))
13137 return join("->", to_member_name(type, index));
13138 else
13139 return join(".", to_member_name(type, index));
13140 }
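// Sketch of the two shapes this produces (names hypothetical): an element of
// a buffer array declared as per-element pointers yields "buffers[i]->member",
// while an ordinary value or reference expression yields "obj.member".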
13141
13142 string CompilerMSL::to_qualifiers_glsl(uint32_t id)
13143 {
13144 string quals;
13145
13146 auto *var = maybe_get<SPIRVariable>(id);
13147 auto &type = expression_type(id);
13148
13149 if (type.storage == StorageClassWorkgroup || (var && variable_decl_is_remapped_storage(*var, StorageClassWorkgroup)))
13150 quals += "threadgroup ";
13151
13152 return quals;
13153 }
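// Illustrative effect (hypothetical name): a workgroup-shared variable
// "float tmp[64]" is declared as "threadgroup float tmp[64];" once this
// qualifier is prepended.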
13154
13155 // The optional id parameter indicates the object whose type we are trying
13156 // to find the description for. Most type descriptions do not depend on a
13157 // specific object's use of that type.
13158 string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id)
13159 {
13160 string type_name;
13161
13162 // Pointer?
13163 if (type.pointer)
13164 {
13165 assert(type.pointer_depth > 0);
13166
13167 const char *restrict_kw;
13168
13169 auto type_address_space = get_type_address_space(type, id);
13170 auto type_decl = type_to_glsl(get<SPIRType>(type.parent_type), id);
13171
13172 // Work around C pointer qualifier rules. If glsl_type is a pointer type as well
13173 // we'll need to emit the address space to the right.
13174 	// We could always go this route, but it makes the code unnatural.
13175 	// Prefer emitting thread T *foo over T thread* foo since it's more readable,
13176 	// but for a pointer to a pointer we must emit the inner address space to the right, e.g. thread T * thread *bar.
13177 if (type_is_pointer_to_pointer(type))
13178 type_name = join(type_decl, " ", type_address_space, " ");
13179 else
13180 type_name = join(type_address_space, " ", type_decl);
13181
13182 switch (type.basetype)
13183 {
13184 case SPIRType::Image:
13185 case SPIRType::SampledImage:
13186 case SPIRType::Sampler:
13187 // These are handles.
13188 break;
13189 default:
13190 // Anything else can be a raw pointer.
13191 type_name += "*";
13192 restrict_kw = to_restrict(id);
13193 if (*restrict_kw)
13194 {
13195 type_name += " ";
13196 type_name += restrict_kw;
13197 }
13198 break;
13199 }
13200 return type_name;
13201 }
13202
13203 switch (type.basetype)
13204 {
13205 case SPIRType::Struct:
13206 // Need OpName lookup here to get a "sensible" name for a struct.
13207 // Allow Metal to use the array<T> template to make arrays a value type
13208 type_name = to_name(type.self);
13209 break;
13210
13211 case SPIRType::Image:
13212 case SPIRType::SampledImage:
13213 return image_type_glsl(type, id);
13214
13215 case SPIRType::Sampler:
13216 return sampler_type(type, id);
13217
13218 case SPIRType::Void:
13219 return "void";
13220
13221 case SPIRType::AtomicCounter:
13222 return "atomic_uint";
13223
13224 case SPIRType::ControlPointArray:
13225 return join("patch_control_point<", type_to_glsl(get<SPIRType>(type.parent_type), id), ">");
13226
13227 case SPIRType::Interpolant:
13228 return join("interpolant<", type_to_glsl(get<SPIRType>(type.parent_type), id), ", interpolation::",
13229 has_decoration(type.self, DecorationNoPerspective) ? "no_perspective" : "perspective", ">");
13230
13231 // Scalars
13232 case SPIRType::Boolean:
13233 type_name = "bool";
13234 break;
13235 case SPIRType::Char:
13236 case SPIRType::SByte:
13237 type_name = "char";
13238 break;
13239 case SPIRType::UByte:
13240 type_name = "uchar";
13241 break;
13242 case SPIRType::Short:
13243 type_name = "short";
13244 break;
13245 case SPIRType::UShort:
13246 type_name = "ushort";
13247 break;
13248 case SPIRType::Int:
13249 type_name = "int";
13250 break;
13251 case SPIRType::UInt:
13252 type_name = "uint";
13253 break;
13254 case SPIRType::Int64:
13255 if (!msl_options.supports_msl_version(2, 2))
13256 SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
13257 type_name = "long";
13258 break;
13259 case SPIRType::UInt64:
13260 if (!msl_options.supports_msl_version(2, 2))
13261 SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
13262 type_name = "ulong";
13263 break;
13264 case SPIRType::Half:
13265 type_name = "half";
13266 break;
13267 case SPIRType::Float:
13268 type_name = "float";
13269 break;
13270 case SPIRType::Double:
13271 type_name = "double"; // Currently unsupported
13272 break;
13273
13274 default:
13275 return "unknown_type";
13276 }
13277
13278 // Matrix?
13279 if (type.columns > 1)
13280 type_name += to_string(type.columns) + "x";
13281
13282 // Vector or Matrix?
13283 if (type.vecsize > 1)
13284 type_name += to_string(type.vecsize);
13285
13286 if (type.array.empty() || using_builtin_array())
13287 {
13288 return type_name;
13289 }
13290 else
13291 {
13292 // Allow Metal to use the array<T> template to make arrays a value type
13293 add_spv_func_and_recompile(SPVFuncImplUnsafeArray);
13294 string res;
13295 string sizes;
13296
13297 for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
13298 {
13299 res += "spvUnsafeArray<";
13300 sizes += ", ";
13301 sizes += to_array_size(type, i);
13302 sizes += ">";
13303 }
13304
13305 res += type_name + sizes;
13306 return res;
13307 }
13308 }
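// A few illustrative mappings (assuming default options): a 4-component float
// vector becomes "float4", a 4x4 float matrix "float4x4", and an array such as
// float[4] becomes the value-type wrapper "spvUnsafeArray<float, 4>" unless a
// plain builtin C array is being emitted.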
13309
13310 string CompilerMSL::type_to_array_glsl(const SPIRType &type)
13311 {
13312 // Allow Metal to use the array<T> template to make arrays a value type
13313 switch (type.basetype)
13314 {
13315 case SPIRType::AtomicCounter:
13316 case SPIRType::ControlPointArray:
13317 {
13318 return CompilerGLSL::type_to_array_glsl(type);
13319 }
13320 default:
13321 {
13322 if (using_builtin_array())
13323 return CompilerGLSL::type_to_array_glsl(type);
13324 else
13325 return "";
13326 }
13327 }
13328 }
13329
13330 bool CompilerMSL::variable_decl_is_remapped_storage(const SPIRVariable &variable, spv::StorageClass storage) const
13331 {
13332 if (variable.storage == storage)
13333 return true;
13334
13335 if (storage == StorageClassWorkgroup)
13336 {
13337 auto model = get_execution_model();
13338
13339 // Specially masked IO block variable.
13340 // Normally, we will never access IO blocks directly here.
13341 		// The only scenario in which that should occur is with a masked IO block.
13342 if (model == ExecutionModelTessellationControl && variable.storage == StorageClassOutput &&
13343 has_decoration(get<SPIRType>(variable.basetype).self, DecorationBlock))
13344 {
13345 return true;
13346 }
13347
13348 return variable.storage == StorageClassOutput &&
13349 model == ExecutionModelTessellationControl &&
13350 is_stage_output_variable_masked(variable);
13351 }
13352 else if (storage == StorageClassStorageBuffer)
13353 {
13354 // We won't be able to catch writes to control point outputs here since variable
13355 // refers to a function local pointer.
13356 		// This is fine, as there cannot be concurrent writers to that memory anyway,
13357 // so we just ignore that case.
13358
13359 return (variable.storage == StorageClassOutput || variable.storage == StorageClassInput) &&
13360 !variable_storage_requires_stage_io(variable.storage) &&
13361 (variable.storage != StorageClassOutput || !is_stage_output_variable_masked(variable));
13362 }
13363 else
13364 {
13365 return false;
13366 }
13367 }
13368
13369 std::string CompilerMSL::variable_decl(const SPIRVariable &variable)
13370 {
13371 bool old_is_using_builtin_array = is_using_builtin_array;
13372
13373 // Threadgroup arrays can't have a wrapper type.
13374 if (variable_decl_is_remapped_storage(variable, StorageClassWorkgroup))
13375 is_using_builtin_array = true;
13376
13377 std::string expr = CompilerGLSL::variable_decl(variable);
13378 is_using_builtin_array = old_is_using_builtin_array;
13379 return expr;
13380 }
13381
13382 // GCC workaround of lambdas calling protected funcs
13383 std::string CompilerMSL::variable_decl(const SPIRType &type, const std::string &name, uint32_t id)
13384 {
13385 return CompilerGLSL::variable_decl(type, name, id);
13386 }
13387
13388 std::string CompilerMSL::sampler_type(const SPIRType &type, uint32_t id)
13389 {
13390 auto *var = maybe_get<SPIRVariable>(id);
13391 if (var && var->basevariable)
13392 {
13393 // Check against the base variable, and not a fake ID which might have been generated for this variable.
13394 id = var->basevariable;
13395 }
13396
13397 if (!type.array.empty())
13398 {
13399 if (!msl_options.supports_msl_version(2))
13400 SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of samplers.");
13401
13402 if (type.array.size() > 1)
13403 SPIRV_CROSS_THROW("Arrays of arrays of samplers are not supported in MSL.");
13404
13405 		// Arrays of samplers in MSL must be declared with a special array<T, N> syntax a la C++11 std::array.
13406 // If we have a runtime array, it could be a variable-count descriptor set binding.
13407 uint32_t array_size = to_array_size_literal(type);
13408 if (array_size == 0)
13409 array_size = get_resource_array_size(id);
13410
13411 if (array_size == 0)
13412 SPIRV_CROSS_THROW("Unsized array of samplers is not supported in MSL.");
13413
13414 auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
13415 return join("array<", sampler_type(parent, id), ", ", array_size, ">");
13416 }
13417 else
13418 return "sampler";
13419 }
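// For instance, a binding declared as an array of four samplers comes out as
// "array<sampler, 4>" (hence the MSL 2.0 requirement), while a lone sampler is
// simply "sampler".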
13420
13421 // Returns an MSL string describing the SPIR-V image type
13422 string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id)
13423 {
13424 auto *var = maybe_get<SPIRVariable>(id);
13425 if (var && var->basevariable)
13426 {
13427 // For comparison images, check against the base variable,
13428 // and not the fake ID which might have been generated for this variable.
13429 id = var->basevariable;
13430 }
13431
13432 if (!type.array.empty())
13433 {
13434 uint32_t major = 2, minor = 0;
13435 if (msl_options.is_ios())
13436 {
13437 major = 1;
13438 minor = 2;
13439 }
13440 if (!msl_options.supports_msl_version(major, minor))
13441 {
13442 if (msl_options.is_ios())
13443 SPIRV_CROSS_THROW("MSL 1.2 or greater is required for arrays of textures.");
13444 else
13445 SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of textures.");
13446 }
13447
13448 if (type.array.size() > 1)
13449 SPIRV_CROSS_THROW("Arrays of arrays of textures are not supported in MSL.");
13450
13451 		// Arrays of images in MSL must be declared with a special array<T, N> syntax a la C++11 std::array.
13452 // If we have a runtime array, it could be a variable-count descriptor set binding.
13453 uint32_t array_size = to_array_size_literal(type);
13454 if (array_size == 0)
13455 array_size = get_resource_array_size(id);
13456
13457 if (array_size == 0)
13458 SPIRV_CROSS_THROW("Unsized array of images is not supported in MSL.");
13459
13460 auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
13461 return join("array<", image_type_glsl(parent, id), ", ", array_size, ">");
13462 }
13463
13464 string img_type_name;
13465
13466 // Bypass pointers because we need the real image struct
13467 auto &img_type = get<SPIRType>(type.self).image;
13468 if (image_is_comparison(type, id))
13469 {
13470 switch (img_type.dim)
13471 {
13472 case Dim1D:
13473 case Dim2D:
13474 if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
13475 {
13476 // Use a native Metal 1D texture
13477 img_type_name += "depth1d_unsupported_by_metal";
13478 break;
13479 }
13480
13481 if (img_type.ms && img_type.arrayed)
13482 {
13483 if (!msl_options.supports_msl_version(2, 1))
13484 SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
13485 img_type_name += "depth2d_ms_array";
13486 }
13487 else if (img_type.ms)
13488 img_type_name += "depth2d_ms";
13489 else if (img_type.arrayed)
13490 img_type_name += "depth2d_array";
13491 else
13492 img_type_name += "depth2d";
13493 break;
13494 case Dim3D:
13495 img_type_name += "depth3d_unsupported_by_metal";
13496 break;
13497 case DimCube:
13498 if (!msl_options.emulate_cube_array)
13499 img_type_name += (img_type.arrayed ? "depthcube_array" : "depthcube");
13500 else
13501 img_type_name += (img_type.arrayed ? "depth2d_array" : "depthcube");
13502 break;
13503 default:
13504 img_type_name += "unknown_depth_texture_type";
13505 break;
13506 }
13507 }
13508 else
13509 {
13510 switch (img_type.dim)
13511 {
13512 case DimBuffer:
13513 if (img_type.ms || img_type.arrayed)
13514 SPIRV_CROSS_THROW("Cannot use texel buffers with multisampling or array layers.");
13515
13516 if (msl_options.texture_buffer_native)
13517 {
13518 if (!msl_options.supports_msl_version(2, 1))
13519 				SPIRV_CROSS_THROW("Native texture_buffer type is only supported in MSL 2.1 and above.");
13520 img_type_name = "texture_buffer";
13521 }
13522 else
13523 img_type_name += "texture2d";
13524 break;
13525 case Dim1D:
13526 case Dim2D:
13527 case DimSubpassData:
13528 {
13529 bool subpass_array =
13530 img_type.dim == DimSubpassData && (msl_options.multiview || msl_options.arrayed_subpass_input);
13531 if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
13532 {
13533 // Use a native Metal 1D texture
13534 img_type_name += (img_type.arrayed ? "texture1d_array" : "texture1d");
13535 break;
13536 }
13537
13538 // Use Metal's native frame-buffer fetch API for subpass inputs.
13539 if (type_is_msl_framebuffer_fetch(type))
13540 {
13541 auto img_type_4 = get<SPIRType>(img_type.type);
13542 img_type_4.vecsize = 4;
13543 return type_to_glsl(img_type_4);
13544 }
13545 if (img_type.ms && (img_type.arrayed || subpass_array))
13546 {
13547 if (!msl_options.supports_msl_version(2, 1))
13548 SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
13549 img_type_name += "texture2d_ms_array";
13550 }
13551 else if (img_type.ms)
13552 img_type_name += "texture2d_ms";
13553 else if (img_type.arrayed || subpass_array)
13554 img_type_name += "texture2d_array";
13555 else
13556 img_type_name += "texture2d";
13557 break;
13558 }
13559 case Dim3D:
13560 img_type_name += "texture3d";
13561 break;
13562 case DimCube:
13563 if (!msl_options.emulate_cube_array)
13564 img_type_name += (img_type.arrayed ? "texturecube_array" : "texturecube");
13565 else
13566 img_type_name += (img_type.arrayed ? "texture2d_array" : "texturecube");
13567 break;
13568 default:
13569 img_type_name += "unknown_texture_type";
13570 break;
13571 }
13572 }
13573
13574 // Append the pixel type
13575 img_type_name += "<";
13576 img_type_name += type_to_glsl(get<SPIRType>(img_type.type));
13577
13578 // For unsampled images, append the sample/read/write access qualifier.
13579 	// For kernel images, the access qualifier may be supplied directly by SPIR-V.
13580 // Otherwise it may be set based on whether the image is read from or written to within the shader.
13581 if (type.basetype == SPIRType::Image && type.image.sampled == 2 && type.image.dim != DimSubpassData)
13582 {
13583 switch (img_type.access)
13584 {
13585 case AccessQualifierReadOnly:
13586 img_type_name += ", access::read";
13587 break;
13588
13589 case AccessQualifierWriteOnly:
13590 img_type_name += ", access::write";
13591 break;
13592
13593 case AccessQualifierReadWrite:
13594 img_type_name += ", access::read_write";
13595 break;
13596
13597 default:
13598 {
13599 auto *p_var = maybe_get_backing_variable(id);
13600 if (p_var && p_var->basevariable)
13601 p_var = maybe_get<SPIRVariable>(p_var->basevariable);
13602 if (p_var && !has_decoration(p_var->self, DecorationNonWritable))
13603 {
13604 img_type_name += ", access::";
13605
13606 if (!has_decoration(p_var->self, DecorationNonReadable))
13607 img_type_name += "read_";
13608
13609 img_type_name += "write";
13610 }
13611 break;
13612 }
13613 }
13614 }
13615
13616 img_type_name += ">";
13617
13618 return img_type_name;
13619 }
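// Illustrative results: a sampled 2D float texture maps to "texture2d<float>",
// a storage image that is both read and written to
// "texture2d<float, access::read_write>", and a comparison (shadow) 2D texture
// to "depth2d<float>".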
13620
13621 void CompilerMSL::emit_subgroup_op(const Instruction &i)
13622 {
13623 const uint32_t *ops = stream(i);
13624 auto op = static_cast<Op>(i.op);
13625
13626 if (msl_options.emulate_subgroups)
13627 {
13628 // In this mode, only the GroupNonUniform cap is supported. The only op
13629 // we need to handle, then, is OpGroupNonUniformElect.
13630 if (op != OpGroupNonUniformElect)
13631 SPIRV_CROSS_THROW("Subgroup emulation does not support operations other than Elect.");
13632 // In this mode, the subgroup size is assumed to be one, so every invocation
13633 // is elected.
13634 emit_op(ops[0], ops[1], "true", true);
13635 return;
13636 }
13637
13638 // Metal 2.0 is required. iOS only supports quad ops on 11.0 (2.0), with
13639 // full support in 13.0 (2.2). macOS only supports broadcast and shuffle on
13640 // 10.13 (2.0), with full support in 10.14 (2.1).
13641 // Note that Apple GPUs before A13 make no distinction between a quad-group
13642 // and a SIMD-group; all SIMD-groups are quad-groups on those.
13643 if (!msl_options.supports_msl_version(2))
13644 SPIRV_CROSS_THROW("Subgroups are only supported in Metal 2.0 and up.");
13645
13646 // If we need to do implicit bitcasts, make sure we do it with the correct type.
13647 uint32_t integer_width = get_integer_width_for_instruction(i);
13648 auto int_type = to_signed_basetype(integer_width);
13649 auto uint_type = to_unsigned_basetype(integer_width);
13650
13651 if (msl_options.is_ios() && (!msl_options.supports_msl_version(2, 3) || !msl_options.ios_use_simdgroup_functions))
13652 {
13653 switch (op)
13654 {
13655 default:
13656 SPIRV_CROSS_THROW("Subgroup ops beyond broadcast, ballot, and shuffle on iOS require Metal 2.3 and up.");
13657 case OpGroupNonUniformBroadcastFirst:
13658 if (!msl_options.supports_msl_version(2, 2))
13659 SPIRV_CROSS_THROW("BroadcastFirst on iOS requires Metal 2.2 and up.");
13660 break;
13661 case OpGroupNonUniformElect:
13662 if (!msl_options.supports_msl_version(2, 2))
13663 SPIRV_CROSS_THROW("Elect on iOS requires Metal 2.2 and up.");
13664 break;
13665 case OpGroupNonUniformAny:
13666 case OpGroupNonUniformAll:
13667 case OpGroupNonUniformAllEqual:
13668 case OpGroupNonUniformBallot:
13669 case OpGroupNonUniformInverseBallot:
13670 case OpGroupNonUniformBallotBitExtract:
13671 case OpGroupNonUniformBallotFindLSB:
13672 case OpGroupNonUniformBallotFindMSB:
13673 case OpGroupNonUniformBallotBitCount:
13674 if (!msl_options.supports_msl_version(2, 2))
13675 				SPIRV_CROSS_THROW("Ballot ops on iOS require Metal 2.2 and up.");
13676 break;
13677 case OpGroupNonUniformBroadcast:
13678 case OpGroupNonUniformShuffle:
13679 case OpGroupNonUniformShuffleXor:
13680 case OpGroupNonUniformShuffleUp:
13681 case OpGroupNonUniformShuffleDown:
13682 case OpGroupNonUniformQuadSwap:
13683 case OpGroupNonUniformQuadBroadcast:
13684 break;
13685 }
13686 }
13687
13688 if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
13689 {
13690 switch (op)
13691 {
13692 default:
13693 SPIRV_CROSS_THROW("Subgroup ops beyond broadcast and shuffle on macOS require Metal 2.1 and up.");
13694 case OpGroupNonUniformBroadcast:
13695 case OpGroupNonUniformShuffle:
13696 case OpGroupNonUniformShuffleXor:
13697 case OpGroupNonUniformShuffleUp:
13698 case OpGroupNonUniformShuffleDown:
13699 break;
13700 }
13701 }
13702
13703 uint32_t result_type = ops[0];
13704 uint32_t id = ops[1];
13705
13706 auto scope = static_cast<Scope>(evaluate_constant_u32(ops[2]));
13707 if (scope != ScopeSubgroup)
13708 SPIRV_CROSS_THROW("Only subgroup scope is supported.");
13709
13710 switch (op)
13711 {
13712 case OpGroupNonUniformElect:
13713 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
13714 emit_op(result_type, id, "quad_is_first()", false);
13715 else
13716 emit_op(result_type, id, "simd_is_first()", false);
13717 break;
13718
13719 case OpGroupNonUniformBroadcast:
13720 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBroadcast");
13721 break;
13722
13723 case OpGroupNonUniformBroadcastFirst:
13724 emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBroadcastFirst");
13725 break;
13726
13727 case OpGroupNonUniformBallot:
13728 emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallot");
13729 break;
13730
13731 case OpGroupNonUniformInverseBallot:
13732 emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_invocation_id_id, "spvSubgroupBallotBitExtract");
13733 break;
13734
13735 case OpGroupNonUniformBallotBitExtract:
13736 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBallotBitExtract");
13737 break;
13738
13739 case OpGroupNonUniformBallotFindLSB:
13740 emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_size_id, "spvSubgroupBallotFindLSB");
13741 break;
13742
13743 case OpGroupNonUniformBallotFindMSB:
13744 emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_size_id, "spvSubgroupBallotFindMSB");
13745 break;
13746
13747 case OpGroupNonUniformBallotBitCount:
13748 {
13749 auto operation = static_cast<GroupOperation>(ops[3]);
13750 switch (operation)
13751 {
13752 case GroupOperationReduce:
13753 emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_size_id, "spvSubgroupBallotBitCount");
13754 break;
13755 case GroupOperationInclusiveScan:
13756 emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
13757 "spvSubgroupBallotInclusiveBitCount");
13758 break;
13759 case GroupOperationExclusiveScan:
13760 emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
13761 "spvSubgroupBallotExclusiveBitCount");
13762 break;
13763 default:
13764 SPIRV_CROSS_THROW("Invalid BitCount operation.");
13765 }
13766 break;
13767 }
13768
13769 case OpGroupNonUniformShuffle:
13770 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffle");
13771 break;
13772
13773 case OpGroupNonUniformShuffleXor:
13774 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleXor");
13775 break;
13776
13777 case OpGroupNonUniformShuffleUp:
13778 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleUp");
13779 break;
13780
13781 case OpGroupNonUniformShuffleDown:
13782 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleDown");
13783 break;
13784
13785 case OpGroupNonUniformAll:
13786 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
13787 emit_unary_func_op(result_type, id, ops[3], "quad_all");
13788 else
13789 emit_unary_func_op(result_type, id, ops[3], "simd_all");
13790 break;
13791
13792 case OpGroupNonUniformAny:
13793 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
13794 emit_unary_func_op(result_type, id, ops[3], "quad_any");
13795 else
13796 emit_unary_func_op(result_type, id, ops[3], "simd_any");
13797 break;
13798
13799 case OpGroupNonUniformAllEqual:
13800 emit_unary_func_op(result_type, id, ops[3], "spvSubgroupAllEqual");
13801 break;
13802
13803 // clang-format off
13804 #define MSL_GROUP_OP(op, msl_op) \
13805 case OpGroupNonUniform##op: \
13806 { \
13807 auto operation = static_cast<GroupOperation>(ops[3]); \
13808 if (operation == GroupOperationReduce) \
13809 emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
13810 else if (operation == GroupOperationInclusiveScan) \
13811 emit_unary_func_op(result_type, id, ops[4], "simd_prefix_inclusive_" #msl_op); \
13812 else if (operation == GroupOperationExclusiveScan) \
13813 emit_unary_func_op(result_type, id, ops[4], "simd_prefix_exclusive_" #msl_op); \
13814 else if (operation == GroupOperationClusteredReduce) \
13815 { \
13816 /* Only cluster sizes of 4 are supported. */ \
13817 uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
13818 if (cluster_size != 4) \
13819 SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
13820 emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
13821 } \
13822 else \
13823 SPIRV_CROSS_THROW("Invalid group operation."); \
13824 break; \
13825 }
13826 MSL_GROUP_OP(FAdd, sum)
13827 MSL_GROUP_OP(FMul, product)
13828 MSL_GROUP_OP(IAdd, sum)
13829 MSL_GROUP_OP(IMul, product)
13830 #undef MSL_GROUP_OP
13831 // The others, unfortunately, don't support InclusiveScan or ExclusiveScan.
13832
13833 #define MSL_GROUP_OP(op, msl_op) \
13834 case OpGroupNonUniform##op: \
13835 { \
13836 auto operation = static_cast<GroupOperation>(ops[3]); \
13837 if (operation == GroupOperationReduce) \
13838 emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
13839 else if (operation == GroupOperationInclusiveScan) \
13840 SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
13841 else if (operation == GroupOperationExclusiveScan) \
13842 SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
13843 else if (operation == GroupOperationClusteredReduce) \
13844 { \
13845 /* Only cluster sizes of 4 are supported. */ \
13846 uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
13847 if (cluster_size != 4) \
13848 SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
13849 emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
13850 } \
13851 else \
13852 SPIRV_CROSS_THROW("Invalid group operation."); \
13853 break; \
13854 }
13855
13856 #define MSL_GROUP_OP_CAST(op, msl_op, type) \
13857 case OpGroupNonUniform##op: \
13858 { \
13859 auto operation = static_cast<GroupOperation>(ops[3]); \
13860 if (operation == GroupOperationReduce) \
13861 emit_unary_func_op_cast(result_type, id, ops[4], "simd_" #msl_op, type, type); \
13862 else if (operation == GroupOperationInclusiveScan) \
13863 SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
13864 else if (operation == GroupOperationExclusiveScan) \
13865 SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
13866 else if (operation == GroupOperationClusteredReduce) \
13867 { \
13868 /* Only cluster sizes of 4 are supported. */ \
13869 uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
13870 if (cluster_size != 4) \
13871 SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
13872 emit_unary_func_op_cast(result_type, id, ops[4], "quad_" #msl_op, type, type); \
13873 } \
13874 else \
13875 SPIRV_CROSS_THROW("Invalid group operation."); \
13876 break; \
13877 }
13878
13879 MSL_GROUP_OP(FMin, min)
13880 MSL_GROUP_OP(FMax, max)
13881 MSL_GROUP_OP_CAST(SMin, min, int_type)
13882 MSL_GROUP_OP_CAST(SMax, max, int_type)
13883 MSL_GROUP_OP_CAST(UMin, min, uint_type)
13884 MSL_GROUP_OP_CAST(UMax, max, uint_type)
13885 MSL_GROUP_OP(BitwiseAnd, and)
13886 MSL_GROUP_OP(BitwiseOr, or)
13887 MSL_GROUP_OP(BitwiseXor, xor)
13888 MSL_GROUP_OP(LogicalAnd, and)
13889 MSL_GROUP_OP(LogicalOr, or)
13890 MSL_GROUP_OP(LogicalXor, xor)
13891 // clang-format on
13892 #undef MSL_GROUP_OP
13893 #undef MSL_GROUP_OP_CAST
13894
13895 case OpGroupNonUniformQuadSwap:
13896 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvQuadSwap");
13897 break;
13898
13899 case OpGroupNonUniformQuadBroadcast:
13900 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvQuadBroadcast");
13901 break;
13902
13903 default:
13904 SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
13905 }
13906
13907 register_control_dependent_expression(id);
13908 }
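// As a concrete sketch of the macros above: OpGroupNonUniformFAdd with
// GroupOperationReduce lowers to "simd_sum(x)", with InclusiveScan to
// "simd_prefix_inclusive_sum(x)", and a clustered reduce of size 4 to
// "quad_sum(x)".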
13909
13910 string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
13911 {
13912 if (out_type.basetype == in_type.basetype)
13913 return "";
13914
13915 assert(out_type.basetype != SPIRType::Boolean);
13916 assert(in_type.basetype != SPIRType::Boolean);
13917
13918 bool integral_cast = type_is_integral(out_type) && type_is_integral(in_type) && (out_type.vecsize == in_type.vecsize);
13919 bool same_size_cast = (out_type.width * out_type.vecsize) == (in_type.width * in_type.vecsize);
13920
13921 // Bitcasting can only be used between types of the same overall size.
13922 // And always formally cast between integers, because it's trivial, and also
13923 // because Metal can internally cast the results of some integer ops to a larger
13924 // size (eg. short shift right becomes int), which means chaining integer ops
13925 // together may introduce size variations that SPIR-V doesn't know about.
13926 if (same_size_cast && !integral_cast)
13927 {
13928 return "as_type<" + type_to_glsl(out_type) + ">";
13929 }
13930 else
13931 {
13932 return type_to_glsl(out_type);
13933 }
13934 }
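// Example outcomes: uint -> float emits "as_type<float>", while int -> uint is
// an ordinary "uint(...)" conversion, since both are integers of the same
// overall size.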
13935
13936 bool CompilerMSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
13937 {
13938 return false;
13939 }
13940
13941 // Returns an MSL string identifying the name of a SPIR-V builtin.
13942 // Output builtins are qualified with the name of the stage out structure.
13943 string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
13944 {
13945 switch (builtin)
13946 {
13947 // Handle HLSL-style 0-based vertex/instance index.
13948 // Override GLSL compiler strictness
13949 case BuiltInVertexId:
13950 ensure_builtin(StorageClassInput, BuiltInVertexId);
13951 if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
13952 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13953 {
13954 if (builtin_declaration)
13955 {
13956 if (needs_base_vertex_arg != TriState::No)
13957 needs_base_vertex_arg = TriState::Yes;
13958 return "gl_VertexID";
13959 }
13960 else
13961 {
13962 ensure_builtin(StorageClassInput, BuiltInBaseVertex);
13963 return "(gl_VertexID - gl_BaseVertex)";
13964 }
13965 }
13966 else
13967 {
13968 return "gl_VertexID";
13969 }
13970 case BuiltInInstanceId:
13971 ensure_builtin(StorageClassInput, BuiltInInstanceId);
13972 if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
13973 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13974 {
13975 if (builtin_declaration)
13976 {
13977 if (needs_base_instance_arg != TriState::No)
13978 needs_base_instance_arg = TriState::Yes;
13979 return "gl_InstanceID";
13980 }
13981 else
13982 {
13983 ensure_builtin(StorageClassInput, BuiltInBaseInstance);
13984 return "(gl_InstanceID - gl_BaseInstance)";
13985 }
13986 }
13987 else
13988 {
13989 return "gl_InstanceID";
13990 }
13991 case BuiltInVertexIndex:
13992 ensure_builtin(StorageClassInput, BuiltInVertexIndex);
13993 if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
13994 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
13995 {
13996 if (builtin_declaration)
13997 {
13998 if (needs_base_vertex_arg != TriState::No)
13999 needs_base_vertex_arg = TriState::Yes;
14000 return "gl_VertexIndex";
14001 }
14002 else
14003 {
14004 ensure_builtin(StorageClassInput, BuiltInBaseVertex);
14005 return "(gl_VertexIndex - gl_BaseVertex)";
14006 }
14007 }
14008 else
14009 {
14010 return "gl_VertexIndex";
14011 }
14012 case BuiltInInstanceIndex:
14013 ensure_builtin(StorageClassInput, BuiltInInstanceIndex);
14014 if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
14015 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
14016 {
14017 if (builtin_declaration)
14018 {
14019 if (needs_base_instance_arg != TriState::No)
14020 needs_base_instance_arg = TriState::Yes;
14021 return "gl_InstanceIndex";
14022 }
14023 else
14024 {
14025 ensure_builtin(StorageClassInput, BuiltInBaseInstance);
14026 return "(gl_InstanceIndex - gl_BaseInstance)";
14027 }
14028 }
14029 else
14030 {
14031 return "gl_InstanceIndex";
14032 }
14033 case BuiltInBaseVertex:
14034 if (msl_options.supports_msl_version(1, 1) &&
14035 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
14036 {
14037 needs_base_vertex_arg = TriState::No;
14038 return "gl_BaseVertex";
14039 }
14040 else
14041 {
14042 SPIRV_CROSS_THROW("BaseVertex requires Metal 1.1 and Mac or Apple A9+ hardware.");
14043 }
14044 case BuiltInBaseInstance:
14045 if (msl_options.supports_msl_version(1, 1) &&
14046 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
14047 {
14048 needs_base_instance_arg = TriState::No;
14049 return "gl_BaseInstance";
14050 }
14051 else
14052 {
14053 SPIRV_CROSS_THROW("BaseInstance requires Metal 1.1 and Mac or Apple A9+ hardware.");
14054 }
14055 case BuiltInDrawIndex:
14056 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
14057
14058 // When used in the entry function, output builtins are qualified with output struct name.
14059 // Test storage class as NOT Input, as output builtins might be part of generic type.
14060 // Also don't do this for tessellation control shaders.
14061 case BuiltInViewportIndex:
14062 if (!msl_options.supports_msl_version(2, 0))
14063 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
14064 /* fallthrough */
14065 case BuiltInFragDepth:
14066 case BuiltInFragStencilRefEXT:
14067 if ((builtin == BuiltInFragDepth && !msl_options.enable_frag_depth_builtin) ||
14068 (builtin == BuiltInFragStencilRefEXT && !msl_options.enable_frag_stencil_ref_builtin))
14069 break;
14070 /* fallthrough */
14071 case BuiltInPosition:
14072 case BuiltInPointSize:
14073 case BuiltInClipDistance:
14074 case BuiltInCullDistance:
14075 case BuiltInLayer:
14076 if (get_execution_model() == ExecutionModelTessellationControl)
14077 break;
14078 if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point) &&
14079 !is_stage_output_builtin_masked(builtin))
14080 return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
14081 break;
14082
14083 case BuiltInSampleMask:
14084 if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point) &&
14085 (has_additional_fixed_sample_mask() || needs_sample_id))
14086 {
14087 string samp_mask_in;
14088 samp_mask_in += "(" + CompilerGLSL::builtin_to_glsl(builtin, storage);
14089 if (has_additional_fixed_sample_mask())
14090 samp_mask_in += " & " + additional_fixed_sample_mask_str();
14091 if (needs_sample_id)
14092 samp_mask_in += " & (1 << gl_SampleID)";
14093 samp_mask_in += ")";
14094 return samp_mask_in;
14095 }
14096 if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point) &&
14097 !is_stage_output_builtin_masked(builtin))
14098 return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
14099 break;
14100
14101 case BuiltInBaryCoordNV:
14102 case BuiltInBaryCoordNoPerspNV:
14103 if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
14104 return stage_in_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
14105 break;
14106
14107 case BuiltInTessLevelOuter:
14108 if (get_execution_model() == ExecutionModelTessellationControl &&
14109 storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
14110 {
14111 return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
14112 "].edgeTessellationFactor");
14113 }
14114 break;
14115
14116 case BuiltInTessLevelInner:
14117 if (get_execution_model() == ExecutionModelTessellationControl &&
14118 storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
14119 {
14120 return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
14121 "].insideTessellationFactor");
14122 }
14123 break;
14124
14125 default:
14126 break;
14127 }
14128
14129 return CompilerGLSL::builtin_to_glsl(builtin, storage);
14130 }
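// For instance, with enable_base_index_zero active on MSL 1.1+, a use of
// InstanceIndex in the shader body is emitted as
// "(gl_InstanceIndex - gl_BaseInstance)" so that indexing starts at zero,
// matching Vulkan semantics.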
14131
14132 // Returns an MSL string attribute qualifier for a SPIR-V builtin
14133 string CompilerMSL::builtin_qualifier(BuiltIn builtin)
14134 {
14135 auto &execution = get_entry_point();
14136
14137 switch (builtin)
14138 {
14139 // Vertex function in
14140 case BuiltInVertexId:
14141 return "vertex_id";
14142 case BuiltInVertexIndex:
14143 return "vertex_id";
14144 case BuiltInBaseVertex:
14145 return "base_vertex";
14146 case BuiltInInstanceId:
14147 return "instance_id";
14148 case BuiltInInstanceIndex:
14149 return "instance_id";
14150 case BuiltInBaseInstance:
14151 return "base_instance";
14152 case BuiltInDrawIndex:
14153 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
14154
14155 // Vertex function out
14156 case BuiltInClipDistance:
14157 return "clip_distance";
14158 case BuiltInPointSize:
14159 return "point_size";
14160 case BuiltInPosition:
14161 if (position_invariant)
14162 {
14163 if (!msl_options.supports_msl_version(2, 1))
14164 SPIRV_CROSS_THROW("Invariant position is only supported on MSL 2.1 and up.");
14165 return "position, invariant";
14166 }
14167 else
14168 return "position";
14169 case BuiltInLayer:
14170 return "render_target_array_index";
14171 case BuiltInViewportIndex:
14172 if (!msl_options.supports_msl_version(2, 0))
14173 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
14174 return "viewport_array_index";
14175
14176 // Tess. control function in
14177 case BuiltInInvocationId:
14178 if (msl_options.multi_patch_workgroup)
14179 {
14180 // Shouldn't be reached.
14181 SPIRV_CROSS_THROW("InvocationId is computed manually with multi-patch workgroups in MSL.");
14182 }
14183 return "thread_index_in_threadgroup";
14184 case BuiltInPatchVertices:
14185 // Shouldn't be reached.
14186 SPIRV_CROSS_THROW("PatchVertices is derived from the auxiliary buffer in MSL.");
14187 case BuiltInPrimitiveId:
14188 switch (execution.model)
14189 {
14190 case ExecutionModelTessellationControl:
14191 if (msl_options.multi_patch_workgroup)
14192 {
14193 // Shouldn't be reached.
14194 SPIRV_CROSS_THROW("PrimitiveId is computed manually with multi-patch workgroups in MSL.");
14195 }
14196 return "threadgroup_position_in_grid";
14197 case ExecutionModelTessellationEvaluation:
14198 return "patch_id";
14199 case ExecutionModelFragment:
14200 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
14201 SPIRV_CROSS_THROW("PrimitiveId on iOS requires MSL 2.3.");
14202 else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 2))
14203 SPIRV_CROSS_THROW("PrimitiveId on macOS requires MSL 2.2.");
14204 return "primitive_id";
14205 default:
14206 SPIRV_CROSS_THROW("PrimitiveId is not supported in this execution model.");
14207 }
14208
14209 // Tess. control function out
14210 case BuiltInTessLevelOuter:
14211 case BuiltInTessLevelInner:
14212 // Shouldn't be reached.
14213 SPIRV_CROSS_THROW("Tessellation levels are handled specially in MSL.");
14214
14215 // Tess. evaluation function in
14216 case BuiltInTessCoord:
14217 return "position_in_patch";
14218
14219 // Fragment function in
14220 case BuiltInFrontFacing:
14221 return "front_facing";
14222 case BuiltInPointCoord:
14223 return "point_coord";
14224 case BuiltInFragCoord:
14225 return "position";
14226 case BuiltInSampleId:
14227 return "sample_id";
14228 case BuiltInSampleMask:
14229 return "sample_mask";
14230 case BuiltInSamplePosition:
14231 // Shouldn't be reached.
14232 SPIRV_CROSS_THROW("Sample position is retrieved by a function in MSL.");
14233 case BuiltInViewIndex:
14234 if (execution.model != ExecutionModelFragment)
14235 SPIRV_CROSS_THROW("ViewIndex is handled specially outside fragment shaders.");
14236 // The ViewIndex was implicitly used in the prior stages to set the render_target_array_index,
14237 // so we can get it from there.
14238 return "render_target_array_index";
14239
14240 // Fragment function out
14241 case BuiltInFragDepth:
14242 if (execution.flags.get(ExecutionModeDepthGreater))
14243 return "depth(greater)";
14244 else if (execution.flags.get(ExecutionModeDepthLess))
14245 return "depth(less)";
14246 else
14247 return "depth(any)";
14248
14249 case BuiltInFragStencilRefEXT:
14250 return "stencil";
14251
14252 // Compute function in
14253 case BuiltInGlobalInvocationId:
14254 return "thread_position_in_grid";
14255
14256 case BuiltInWorkgroupId:
14257 return "threadgroup_position_in_grid";
14258
14259 case BuiltInNumWorkgroups:
14260 return "threadgroups_per_grid";
14261
14262 case BuiltInLocalInvocationId:
14263 return "thread_position_in_threadgroup";
14264
14265 case BuiltInLocalInvocationIndex:
14266 return "thread_index_in_threadgroup";
14267
14268 case BuiltInSubgroupSize:
14269 if (msl_options.emulate_subgroups || msl_options.fixed_subgroup_size != 0)
14270 // Shouldn't be reached.
14271 SPIRV_CROSS_THROW("Emitting threads_per_simdgroup attribute with fixed subgroup size??");
14272 if (execution.model == ExecutionModelFragment)
14273 {
14274 if (!msl_options.supports_msl_version(2, 2))
14275 SPIRV_CROSS_THROW("threads_per_simdgroup requires Metal 2.2 in fragment shaders.");
14276 return "threads_per_simdgroup";
14277 }
14278 else
14279 {
14280 			// thread_execution_width is an alias for threads_per_simdgroup; it has been
14281 			// available since MSL 1.0, but not in fragment functions.
14282 return "thread_execution_width";
14283 }
14284
14285 case BuiltInNumSubgroups:
14286 if (msl_options.emulate_subgroups)
14287 // Shouldn't be reached.
14288 SPIRV_CROSS_THROW("NumSubgroups is handled specially with emulation.");
14289 if (!msl_options.supports_msl_version(2))
14290 SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
14291 return msl_options.is_ios() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
14292
14293 case BuiltInSubgroupId:
14294 if (msl_options.emulate_subgroups)
14295 // Shouldn't be reached.
14296 SPIRV_CROSS_THROW("SubgroupId is handled specially with emulation.");
14297 if (!msl_options.supports_msl_version(2))
14298 SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
14299 return msl_options.is_ios() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
14300
14301 case BuiltInSubgroupLocalInvocationId:
14302 if (msl_options.emulate_subgroups)
14303 // Shouldn't be reached.
14304 SPIRV_CROSS_THROW("SubgroupLocalInvocationId is handled specially with emulation.");
14305 if (execution.model == ExecutionModelFragment)
14306 {
14307 if (!msl_options.supports_msl_version(2, 2))
14308 SPIRV_CROSS_THROW("thread_index_in_simdgroup requires Metal 2.2 in fragment shaders.");
14309 return "thread_index_in_simdgroup";
14310 }
14311 else if (execution.model == ExecutionModelKernel || execution.model == ExecutionModelGLCompute ||
14312 execution.model == ExecutionModelTessellationControl ||
14313 (execution.model == ExecutionModelVertex && msl_options.vertex_for_tessellation))
14314 {
14315 // We are generating a Metal kernel function.
14316 if (!msl_options.supports_msl_version(2))
14317 SPIRV_CROSS_THROW("Subgroup builtins in kernel functions require Metal 2.0.");
14318 return msl_options.is_ios() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
14319 }
14320 else
14321 SPIRV_CROSS_THROW("Subgroup builtins are not available in this type of function.");
14322
14323 case BuiltInSubgroupEqMask:
14324 case BuiltInSubgroupGeMask:
14325 case BuiltInSubgroupGtMask:
14326 case BuiltInSubgroupLeMask:
14327 case BuiltInSubgroupLtMask:
14328 // Shouldn't be reached.
14329 SPIRV_CROSS_THROW("Subgroup ballot masks are handled specially in MSL.");
14330
14331 case BuiltInBaryCoordNV:
14332 // TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
14333 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
14334 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.3 and above on iOS.");
14335 else if (!msl_options.supports_msl_version(2, 2))
14336 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
14337 return "barycentric_coord, center_perspective";
14338
14339 case BuiltInBaryCoordNoPerspNV:
14340 // TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
14341 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
14342 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.3 and above on iOS.");
14343 else if (!msl_options.supports_msl_version(2, 2))
14344 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
14345 return "barycentric_coord, center_no_perspective";
14346
14347 default:
14348 return "unsupported-built-in";
14349 }
14350 }
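// E.g. BuiltInFragCoord maps to "position" and BuiltInFragDepth to
// "depth(any)" (or depth(greater)/depth(less) under the matching execution
// modes); callers wrap the result in [[...]].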
14351
14352 // Returns an MSL string type declaration for a SPIR-V builtin
14353 string CompilerMSL::builtin_type_decl(BuiltIn builtin, uint32_t id)
14354 {
14355 const SPIREntryPoint &execution = get_entry_point();
14356 switch (builtin)
14357 {
14358 // Vertex function in
14359 case BuiltInVertexId:
14360 return "uint";
14361 case BuiltInVertexIndex:
14362 return "uint";
14363 case BuiltInBaseVertex:
14364 return "uint";
14365 case BuiltInInstanceId:
14366 return "uint";
14367 case BuiltInInstanceIndex:
14368 return "uint";
14369 case BuiltInBaseInstance:
14370 return "uint";
14371 case BuiltInDrawIndex:
14372 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
14373
14374 // Vertex function out
14375 case BuiltInClipDistance:
14376 case BuiltInCullDistance:
14377 return "float";
14378 case BuiltInPointSize:
14379 return "float";
14380 case BuiltInPosition:
14381 return "float4";
14382 case BuiltInLayer:
14383 return "uint";
14384 case BuiltInViewportIndex:
14385 if (!msl_options.supports_msl_version(2, 0))
14386 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
14387 return "uint";
14388
14389 // Tess. control function in
14390 case BuiltInInvocationId:
14391 return "uint";
14392 case BuiltInPatchVertices:
14393 return "uint";
14394 case BuiltInPrimitiveId:
14395 return "uint";
14396
14397 // Tess. control function out
14398 case BuiltInTessLevelInner:
14399 if (execution.model == ExecutionModelTessellationEvaluation)
14400 return !execution.flags.get(ExecutionModeTriangles) ? "float2" : "float";
14401 return "half";
14402 case BuiltInTessLevelOuter:
14403 if (execution.model == ExecutionModelTessellationEvaluation)
14404 return !execution.flags.get(ExecutionModeTriangles) ? "float4" : "float";
14405 return "half";
14406
14407 // Tess. evaluation function in
14408 case BuiltInTessCoord:
14409 return execution.flags.get(ExecutionModeTriangles) ? "float3" : "float2";
14410
14411 // Fragment function in
14412 case BuiltInFrontFacing:
14413 return "bool";
14414 case BuiltInPointCoord:
14415 return "float2";
14416 case BuiltInFragCoord:
14417 return "float4";
14418 case BuiltInSampleId:
14419 return "uint";
14420 case BuiltInSampleMask:
14421 return "uint";
14422 case BuiltInSamplePosition:
14423 return "float2";
14424 case BuiltInViewIndex:
14425 return "uint";
14426
14427 case BuiltInHelperInvocation:
14428 return "bool";
14429
14430 case BuiltInBaryCoordNV:
14431 case BuiltInBaryCoordNoPerspNV:
14432 // Use the type as declared, can be 1, 2 or 3 components.
14433 return type_to_glsl(get_variable_data_type(get<SPIRVariable>(id)));
14434
14435 // Fragment function out
14436 case BuiltInFragDepth:
14437 return "float";
14438
14439 case BuiltInFragStencilRefEXT:
14440 return "uint";
14441
14442 // Compute function in
14443 case BuiltInGlobalInvocationId:
14444 case BuiltInLocalInvocationId:
14445 case BuiltInNumWorkgroups:
14446 case BuiltInWorkgroupId:
14447 return "uint3";
14448 case BuiltInLocalInvocationIndex:
14449 case BuiltInNumSubgroups:
14450 case BuiltInSubgroupId:
14451 case BuiltInSubgroupSize:
14452 case BuiltInSubgroupLocalInvocationId:
14453 return "uint";
14454 case BuiltInSubgroupEqMask:
14455 case BuiltInSubgroupGeMask:
14456 case BuiltInSubgroupGtMask:
14457 case BuiltInSubgroupLeMask:
14458 case BuiltInSubgroupLtMask:
14459 return "uint4";
14460
14461 case BuiltInDeviceIndex:
14462 return "int";
14463
14464 default:
14465 return "unsupported-built-in-type";
14466 }
14467 }
14468
14469 // Returns the declaration of a built-in argument to a function
14470 string CompilerMSL::built_in_func_arg(BuiltIn builtin, bool prefix_comma)
14471 {
14472 string bi_arg;
14473 if (prefix_comma)
14474 bi_arg += ", ";
14475
14476 // Handle HLSL-style 0-based vertex/instance index.
14477 builtin_declaration = true;
14478 bi_arg += builtin_type_decl(builtin);
14479 bi_arg += " " + builtin_to_glsl(builtin, StorageClassInput);
14480 bi_arg += " [[" + builtin_qualifier(builtin) + "]]";
14481 builtin_declaration = false;
14482
14483 return bi_arg;
14484 }
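
// A minimal sketch of the result (illustrative, not a verbatim emission): for
// BuiltInVertexIndex with prefix_comma == true, the fragment built above would
// look roughly like
//
//     , uint gl_VertexIndex [[vertex_id]]
//
// where the name comes from builtin_to_glsl() and the bracketed attribute from
// builtin_qualifier().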
14485
14486 const SPIRType &CompilerMSL::get_physical_member_type(const SPIRType &type, uint32_t index) const
14487 {
14488 if (member_is_remapped_physical_type(type, index))
14489 return get<SPIRType>(get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID));
14490 else
14491 return get<SPIRType>(type.member_types[index]);
14492 }
14493
14494 SPIRType CompilerMSL::get_presumed_input_type(const SPIRType &ib_type, uint32_t index) const
14495 {
14496 SPIRType type = get_physical_member_type(ib_type, index);
14497 uint32_t loc = get_member_decoration(ib_type.self, index, DecorationLocation);
14498 if (inputs_by_location.count(loc))
14499 {
14500 if (inputs_by_location.at(loc).vecsize > type.vecsize)
14501 type.vecsize = inputs_by_location.at(loc).vecsize;
14502 }
14503 return type;
14504 }
14505
14506 uint32_t CompilerMSL::get_declared_type_array_stride_msl(const SPIRType &type, bool is_packed, bool row_major) const
14507 {
14508 // Array stride in MSL is always size * array_size. sizeof(float3) == 16,
14509 // unlike GLSL and HLSL where array stride would be 16 and size 12.
14510
14511 // We could use parent type here and recurse, but that makes creating physical type remappings
14512 // far more complicated. We'd rather just create the final type, and ignore having to create the entire type
14513 // hierarchy in order to compute this value, so make a temporary type on the stack.
14514
14515 auto basic_type = type;
14516 basic_type.array.clear();
14517 basic_type.array_size_literal.clear();
14518 uint32_t value_size = get_declared_type_size_msl(basic_type, is_packed, row_major);
14519
14520 uint32_t dimensions = uint32_t(type.array.size());
14521 assert(dimensions > 0);
14522 dimensions--;
14523
14524 // Multiply together every dimension, except the last one.
14525 for (uint32_t dim = 0; dim < dimensions; dim++)
14526 {
14527 uint32_t array_size = to_array_size_literal(type, dim);
14528 value_size *= max(array_size, 1u);
14529 }
14530
14531 return value_size;
14532 }
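
// Worked example of the stride rule above: for a non-packed float3 foo[4],
// the element size is 16 (sizeof(float3) == 16 in MSL), so the declared array
// stride is 16 and the total declared size is 64. For multi-dimensional
// arrays, every dimension except the outermost (the last entry in type.array)
// is folded into the stride computed here.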
14533
14534 uint32_t CompilerMSL::get_declared_struct_member_array_stride_msl(const SPIRType &type, uint32_t index) const
14535 {
14536 return get_declared_type_array_stride_msl(get_physical_member_type(type, index),
14537 member_is_packed_physical_type(type, index),
14538 has_member_decoration(type.self, index, DecorationRowMajor));
14539 }
14540
14541 uint32_t CompilerMSL::get_declared_input_array_stride_msl(const SPIRType &type, uint32_t index) const
14542 {
14543 return get_declared_type_array_stride_msl(get_presumed_input_type(type, index), false,
14544 has_member_decoration(type.self, index, DecorationRowMajor));
14545 }
14546
14547 uint32_t CompilerMSL::get_declared_type_matrix_stride_msl(const SPIRType &type, bool packed, bool row_major) const
14548 {
14549 // For packed matrices, we just use the size of the vector type.
14550 // Otherwise, MatrixStride == alignment, which is the size of the underlying vector type.
14551 if (packed)
14552 return (type.width / 8) * ((row_major && type.columns > 1) ? type.columns : type.vecsize);
14553 else
14554 return get_declared_type_alignment_msl(type, false, row_major);
14555 }
14556
14557 uint32_t CompilerMSL::get_declared_struct_member_matrix_stride_msl(const SPIRType &type, uint32_t index) const
14558 {
14559 return get_declared_type_matrix_stride_msl(get_physical_member_type(type, index),
14560 member_is_packed_physical_type(type, index),
14561 has_member_decoration(type.self, index, DecorationRowMajor));
14562 }
14563
14564 uint32_t CompilerMSL::get_declared_input_matrix_stride_msl(const SPIRType &type, uint32_t index) const
14565 {
14566 return get_declared_type_matrix_stride_msl(get_presumed_input_type(type, index), false,
14567 has_member_decoration(type.self, index, DecorationRowMajor));
14568 }
14569
14570 uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type, bool ignore_alignment,
14571 bool ignore_padding) const
14572 {
14573 // If we have a target size, that is the declared size as well.
14574 if (!ignore_padding && has_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget))
14575 return get_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget);
14576
14577 if (struct_type.member_types.empty())
14578 return 0;
14579
14580 uint32_t mbr_cnt = uint32_t(struct_type.member_types.size());
14581
14582 // In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
14583 uint32_t alignment = 1;
14584
14585 if (!ignore_alignment)
14586 {
14587 for (uint32_t i = 0; i < mbr_cnt; i++)
14588 {
14589 uint32_t mbr_alignment = get_declared_struct_member_alignment_msl(struct_type, i);
14590 alignment = max(alignment, mbr_alignment);
14591 }
14592 }
14593
14594 // The last member is always matched to its final Offset decoration, but the size of the struct in MSL
14595 // depends on its physical size in MSL, and the struct size is then rounded up to the struct alignment.
14596 uint32_t spirv_offset = type_struct_member_offset(struct_type, mbr_cnt - 1);
14597 uint32_t msl_size = spirv_offset + get_declared_struct_member_size_msl(struct_type, mbr_cnt - 1);
14598 msl_size = (msl_size + alignment - 1) & ~(alignment - 1);
14599 return msl_size;
14600 }
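
// Worked arithmetic sketch of the final rounding above: if the last member has
// Offset 240 and a declared MSL size of 12, msl_size starts at 252; with a
// struct alignment of 16, (252 + 15) & ~15 yields a declared size of 256.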
14601
14602 // Returns the declared byte size of a type, as laid out in an MSL buffer.
14603 uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const
14604 {
14605 switch (type.basetype)
14606 {
14607 case SPIRType::Unknown:
14608 case SPIRType::Void:
14609 case SPIRType::AtomicCounter:
14610 case SPIRType::Image:
14611 case SPIRType::SampledImage:
14612 case SPIRType::Sampler:
14613 SPIRV_CROSS_THROW("Querying size of opaque object.");
14614
14615 default:
14616 {
14617 if (!type.array.empty())
14618 {
14619 uint32_t array_size = to_array_size_literal(type);
14620 return get_declared_type_array_stride_msl(type, is_packed, row_major) * max(array_size, 1u);
14621 }
14622
14623 if (type.basetype == SPIRType::Struct)
14624 return get_declared_struct_size_msl(type);
14625
14626 if (is_packed)
14627 {
14628 return type.vecsize * type.columns * (type.width / 8);
14629 }
14630 else
14631 {
14632 // An unpacked 3-element vector or matrix column is the same memory size as a 4-element.
14633 uint32_t vecsize = type.vecsize;
14634 uint32_t columns = type.columns;
14635
14636 if (row_major && columns > 1)
14637 swap(vecsize, columns);
14638
14639 if (vecsize == 3)
14640 vecsize = 4;
14641
14642 return vecsize * columns * (type.width / 8);
14643 }
14644 }
14645 }
14646 }
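
// Examples of the packed/unpacked distinction above: a packed_float3 occupies
// 3 * 4 = 12 bytes, while an unpacked float3 rounds vecsize 3 up to 4 and
// occupies 16 bytes. A row-major mat2x3 (2 columns of 3 rows) swaps vecsize
// and columns first, so it is sized as 3 rows of float2: 3 * 2 * 4 = 24 bytes.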
14647
14648 uint32_t CompilerMSL::get_declared_struct_member_size_msl(const SPIRType &type, uint32_t index) const
14649 {
14650 return get_declared_type_size_msl(get_physical_member_type(type, index),
14651 member_is_packed_physical_type(type, index),
14652 has_member_decoration(type.self, index, DecorationRowMajor));
14653 }
14654
14655 uint32_t CompilerMSL::get_declared_input_size_msl(const SPIRType &type, uint32_t index) const
14656 {
14657 return get_declared_type_size_msl(get_presumed_input_type(type, index), false,
14658 has_member_decoration(type.self, index, DecorationRowMajor));
14659 }
14660
14661 // Returns the byte alignment of a type.
14662 uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed, bool row_major) const
14663 {
14664 switch (type.basetype)
14665 {
14666 case SPIRType::Unknown:
14667 case SPIRType::Void:
14668 case SPIRType::AtomicCounter:
14669 case SPIRType::Image:
14670 case SPIRType::SampledImage:
14671 case SPIRType::Sampler:
14672 SPIRV_CROSS_THROW("Querying alignment of opaque object.");
14673
14674 case SPIRType::Double:
14675 SPIRV_CROSS_THROW("double types are not supported in buffers in MSL.");
14676
14677 case SPIRType::Struct:
14678 {
14679 // In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
14680 uint32_t alignment = 1;
14681 for (uint32_t i = 0; i < type.member_types.size(); i++)
14682 alignment = max(alignment, uint32_t(get_declared_struct_member_alignment_msl(type, i)));
14683 return alignment;
14684 }
14685
14686 default:
14687 {
14688 if (type.basetype == SPIRType::Int64 && !msl_options.supports_msl_version(2, 3))
14689 SPIRV_CROSS_THROW("long types in buffers are only supported in MSL 2.3 and above.");
14690 if (type.basetype == SPIRType::UInt64 && !msl_options.supports_msl_version(2, 3))
14691 SPIRV_CROSS_THROW("ulong types in buffers are only supported in MSL 2.3 and above.");
14692 // The alignment of a packed type is the same as its underlying component or column size.
14693 // The alignment of an unpacked type is the same as its vector size.
14694 // A 3-element vector (or matrix column) has the same alignment as a 4-element one.
14695 if (is_packed)
14696 {
14697 // If we have packed_T and friends, the alignment is always scalar.
14698 return type.width / 8;
14699 }
14700 else
14701 {
14702 // This is the general rule for MSL. Size == alignment.
14703 uint32_t vecsize = (row_major && type.columns > 1) ? type.columns : type.vecsize;
14704 return (type.width / 8) * (vecsize == 3 ? 4 : vecsize);
14705 }
14706 }
14707 }
14708 }
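
// Examples of the alignment rules above: packed_float3 aligns to 4 (the scalar
// width), while a non-packed float3 aligns to 16, the same as float4. Half
// vectors follow the same pattern at half the byte width.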
14709
14710 uint32_t CompilerMSL::get_declared_struct_member_alignment_msl(const SPIRType &type, uint32_t index) const
14711 {
14712 return get_declared_type_alignment_msl(get_physical_member_type(type, index),
14713 member_is_packed_physical_type(type, index),
14714 has_member_decoration(type.self, index, DecorationRowMajor));
14715 }
14716
14717 uint32_t CompilerMSL::get_declared_input_alignment_msl(const SPIRType &type, uint32_t index) const
14718 {
14719 return get_declared_type_alignment_msl(get_presumed_input_type(type, index), false,
14720 has_member_decoration(type.self, index, DecorationRowMajor));
14721 }
14722
14723 bool CompilerMSL::skip_argument(uint32_t) const
14724 {
14725 return false;
14726 }
14727
14728 void CompilerMSL::analyze_sampled_image_usage()
14729 {
14730 if (msl_options.swizzle_texture_samples)
14731 {
14732 SampledImageScanner scanner(*this);
14733 traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), scanner);
14734 }
14735 }
14736
14737 bool CompilerMSL::SampledImageScanner::handle(spv::Op opcode, const uint32_t *args, uint32_t length)
14738 {
14739 switch (opcode)
14740 {
14741 case OpLoad:
14742 case OpImage:
14743 case OpSampledImage:
14744 {
14745 if (length < 3)
14746 return false;
14747
14748 uint32_t result_type = args[0];
14749 auto &type = compiler.get<SPIRType>(result_type);
14750 if ((type.basetype != SPIRType::Image && type.basetype != SPIRType::SampledImage) || type.image.sampled != 1)
14751 return true;
14752
14753 uint32_t id = args[1];
14754 compiler.set<SPIRExpression>(id, "", result_type, true);
14755 break;
14756 }
14757 case OpImageSampleExplicitLod:
14758 case OpImageSampleProjExplicitLod:
14759 case OpImageSampleDrefExplicitLod:
14760 case OpImageSampleProjDrefExplicitLod:
14761 case OpImageSampleImplicitLod:
14762 case OpImageSampleProjImplicitLod:
14763 case OpImageSampleDrefImplicitLod:
14764 case OpImageSampleProjDrefImplicitLod:
14765 case OpImageFetch:
14766 case OpImageGather:
14767 case OpImageDrefGather:
14768 compiler.has_sampled_images =
14769 compiler.has_sampled_images || compiler.is_sampled_image_type(compiler.expression_type(args[2]));
14770 compiler.needs_swizzle_buffer_def = compiler.needs_swizzle_buffer_def || compiler.has_sampled_images;
14771 break;
14772 default:
14773 break;
14774 }
14775 return true;
14776 }
14777
14778 // If a needed custom function wasn't added before, add it and force a recompile.
14779 void CompilerMSL::add_spv_func_and_recompile(SPVFuncImpl spv_func)
14780 {
14781 if (spv_function_implementations.count(spv_func) == 0)
14782 {
14783 spv_function_implementations.insert(spv_func);
14784 suppress_missing_prototypes = true;
14785 force_recompile();
14786 }
14787 }
14788
14789 bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, uint32_t length)
14790 {
14791 // Since MSL exists in a single execution scope, function prototype declarations are not
14792 // needed, and clutter the output. If secondary functions are output (either as a SPIR-V
14793 // function implementation or as indicated by the presence of OpFunctionCall), then set
14794 // suppress_missing_prototypes to suppress compiler warnings of missing function prototypes.
14795
14796 // Mark it if the input requires the implementation of a SPIR-V function that does not exist in Metal.
14797 SPVFuncImpl spv_func = get_spv_func_impl(opcode, args);
14798 if (spv_func != SPVFuncImplNone)
14799 {
14800 compiler.spv_function_implementations.insert(spv_func);
14801 suppress_missing_prototypes = true;
14802 }
14803
14804 switch (opcode)
14805 {
14806
14807 case OpFunctionCall:
14808 suppress_missing_prototypes = true;
14809 break;
14810
14811 // Emulate texture2D atomic operations
14812 case OpImageTexelPointer:
14813 {
14814 auto *var = compiler.maybe_get_backing_variable(args[2]);
14815 image_pointers[args[1]] = var ? var->self : ID(0);
14816 break;
14817 }
14818
14819 case OpImageWrite:
14820 if (!compiler.msl_options.supports_msl_version(2, 2))
14821 uses_resource_write = true;
14822 break;
14823
14824 case OpStore:
14825 check_resource_write(args[0]);
14826 break;
14827
14828 // Emulate texture2D atomic operations
14829 case OpAtomicExchange:
14830 case OpAtomicCompareExchange:
14831 case OpAtomicCompareExchangeWeak:
14832 case OpAtomicIIncrement:
14833 case OpAtomicIDecrement:
14834 case OpAtomicIAdd:
14835 case OpAtomicISub:
14836 case OpAtomicSMin:
14837 case OpAtomicUMin:
14838 case OpAtomicSMax:
14839 case OpAtomicUMax:
14840 case OpAtomicAnd:
14841 case OpAtomicOr:
14842 case OpAtomicXor:
14843 {
14844 uses_atomics = true;
14845 auto it = image_pointers.find(args[2]);
14846 if (it != image_pointers.end())
14847 {
14848 compiler.atomic_image_vars.insert(it->second);
14849 }
14850 check_resource_write(args[2]);
14851 break;
14852 }
14853
14854 case OpAtomicStore:
14855 {
14856 uses_atomics = true;
14857 auto it = image_pointers.find(args[0]);
14858 if (it != image_pointers.end())
14859 {
14860 compiler.atomic_image_vars.insert(it->second);
14861 }
14862 check_resource_write(args[0]);
14863 break;
14864 }
14865
14866 case OpAtomicLoad:
14867 {
14868 uses_atomics = true;
14869 auto it = image_pointers.find(args[2]);
14870 if (it != image_pointers.end())
14871 {
14872 compiler.atomic_image_vars.insert(it->second);
14873 }
14874 break;
14875 }
14876
14877 case OpGroupNonUniformInverseBallot:
14878 needs_subgroup_invocation_id = true;
14879 break;
14880
14881 case OpGroupNonUniformBallotFindLSB:
14882 case OpGroupNonUniformBallotFindMSB:
14883 needs_subgroup_size = true;
14884 break;
14885
14886 case OpGroupNonUniformBallotBitCount:
14887 if (args[3] == GroupOperationReduce)
14888 needs_subgroup_size = true;
14889 else
14890 needs_subgroup_invocation_id = true;
14891 break;
14892
14893 case OpArrayLength:
14894 {
14895 auto *var = compiler.maybe_get_backing_variable(args[2]);
14896 if (var)
14897 compiler.buffers_requiring_array_length.insert(var->self);
14898 break;
14899 }
14900
14901 case OpInBoundsAccessChain:
14902 case OpAccessChain:
14903 case OpPtrAccessChain:
14904 {
14905 // OpArrayLength might need to know whether we are taking the ArrayLength of an array of SSBOs.
14906 uint32_t result_type = args[0];
14907 uint32_t id = args[1];
14908 uint32_t ptr = args[2];
14909
14910 compiler.set<SPIRExpression>(id, "", result_type, true);
14911 compiler.register_read(id, ptr, true);
14912 compiler.ir.ids[id].set_allow_type_rewrite();
14913 break;
14914 }
14915
14916 case OpExtInst:
14917 {
14918 uint32_t extension_set = args[2];
14919 if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
14920 {
14921 auto op_450 = static_cast<GLSLstd450>(args[3]);
14922 switch (op_450)
14923 {
14924 case GLSLstd450InterpolateAtCentroid:
14925 case GLSLstd450InterpolateAtSample:
14926 case GLSLstd450InterpolateAtOffset:
14927 {
14928 if (!compiler.msl_options.supports_msl_version(2, 3))
14929 SPIRV_CROSS_THROW("Pull-model interpolation requires MSL 2.3.");
14930 // Fragment varyings used with pull-model interpolation need special handling,
14931 // due to the way pull-model interpolation works in Metal.
14932 auto *var = compiler.maybe_get_backing_variable(args[4]);
14933 if (var)
14934 {
14935 compiler.pull_model_inputs.insert(var->self);
14936 auto &var_type = compiler.get_variable_element_type(*var);
14937 // In addition, if this variable has a 'Sample' decoration, we need the sample ID
14938 // in order to do default interpolation.
14939 if (compiler.has_decoration(var->self, DecorationSample))
14940 {
14941 needs_sample_id = true;
14942 }
14943 else if (var_type.basetype == SPIRType::Struct)
14944 {
14945 // Now we need to check each member and see if it has this decoration.
14946 for (uint32_t i = 0; i < var_type.member_types.size(); ++i)
14947 {
14948 if (compiler.has_member_decoration(var_type.self, i, DecorationSample))
14949 {
14950 needs_sample_id = true;
14951 break;
14952 }
14953 }
14954 }
14955 }
14956 break;
14957 }
14958 default:
14959 break;
14960 }
14961 }
14962 break;
14963 }
14964
14965 default:
14966 break;
14967 }
14968
14969 // If it has one, keep track of the instruction's result type, mapped by ID
14970 uint32_t result_type, result_id;
14971 if (compiler.instruction_to_result_type(result_type, result_id, opcode, args, length))
14972 result_types[result_id] = result_type;
14973
14974 return true;
14975 }
14976
14977 // If the variable is a Uniform or StorageBuffer, mark that a resource has been written to.
14978 void CompilerMSL::OpCodePreprocessor::check_resource_write(uint32_t var_id)
14979 {
14980 auto *p_var = compiler.maybe_get_backing_variable(var_id);
14981 StorageClass sc = p_var ? p_var->storage : StorageClassMax;
14982 if (!compiler.msl_options.supports_msl_version(2, 1) &&
14983 (sc == StorageClassUniform || sc == StorageClassStorageBuffer))
14984 uses_resource_write = true;
14985 }
14986
14987 // Returns an enumeration of a SPIR-V function that needs to be output for certain Op codes.
14988 CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op opcode, const uint32_t *args)
14989 {
14990 switch (opcode)
14991 {
14992 case OpFMod:
14993 return SPVFuncImplMod;
14994
14995 case OpFAdd:
14996 case OpFSub:
14997 if (compiler.msl_options.invariant_float_math ||
14998 compiler.has_decoration(args[1], DecorationNoContraction))
14999 {
15000 return opcode == OpFAdd ? SPVFuncImplFAdd : SPVFuncImplFSub;
15001 }
15002 break;
15003
15004 case OpFMul:
15005 case OpOuterProduct:
15006 case OpMatrixTimesVector:
15007 case OpVectorTimesMatrix:
15008 case OpMatrixTimesMatrix:
15009 if (compiler.msl_options.invariant_float_math ||
15010 compiler.has_decoration(args[1], DecorationNoContraction))
15011 {
15012 return SPVFuncImplFMul;
15013 }
15014 break;
15015
15016 case OpTypeArray:
15017 {
15018 // Allow Metal to use the array<T> template to make arrays a value type
15019 return SPVFuncImplUnsafeArray;
15020 }
15021
15022 // Emulate texture2D atomic operations
15023 case OpAtomicExchange:
15024 case OpAtomicCompareExchange:
15025 case OpAtomicCompareExchangeWeak:
15026 case OpAtomicIIncrement:
15027 case OpAtomicIDecrement:
15028 case OpAtomicIAdd:
15029 case OpAtomicISub:
15030 case OpAtomicSMin:
15031 case OpAtomicUMin:
15032 case OpAtomicSMax:
15033 case OpAtomicUMax:
15034 case OpAtomicAnd:
15035 case OpAtomicOr:
15036 case OpAtomicXor:
15037 case OpAtomicLoad:
15038 case OpAtomicStore:
15039 {
15040 auto it = image_pointers.find(args[opcode == OpAtomicStore ? 0 : 2]);
15041 if (it != image_pointers.end())
15042 {
15043 uint32_t tid = compiler.get<SPIRVariable>(it->second).basetype;
15044 if (tid && compiler.get<SPIRType>(tid).image.dim == Dim2D)
15045 return SPVFuncImplImage2DAtomicCoords;
15046 }
15047 break;
15048 }
15049
15050 case OpImageFetch:
15051 case OpImageRead:
15052 case OpImageWrite:
15053 {
15054 // Retrieve the image type, and if it's a Buffer, emit a texel coordinate function
15055 uint32_t tid = result_types[args[opcode == OpImageWrite ? 0 : 2]];
15056 if (tid && compiler.get<SPIRType>(tid).image.dim == DimBuffer && !compiler.msl_options.texture_buffer_native)
15057 return SPVFuncImplTexelBufferCoords;
15058 break;
15059 }
15060
15061 case OpExtInst:
15062 {
15063 uint32_t extension_set = args[2];
15064 if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
15065 {
15066 auto op_450 = static_cast<GLSLstd450>(args[3]);
15067 switch (op_450)
15068 {
15069 case GLSLstd450Radians:
15070 return SPVFuncImplRadians;
15071 case GLSLstd450Degrees:
15072 return SPVFuncImplDegrees;
15073 case GLSLstd450FindILsb:
15074 return SPVFuncImplFindILsb;
15075 case GLSLstd450FindSMsb:
15076 return SPVFuncImplFindSMsb;
15077 case GLSLstd450FindUMsb:
15078 return SPVFuncImplFindUMsb;
15079 case GLSLstd450SSign:
15080 return SPVFuncImplSSign;
15081 case GLSLstd450Reflect:
15082 {
15083 auto &type = compiler.get<SPIRType>(args[0]);
15084 if (type.vecsize == 1)
15085 return SPVFuncImplReflectScalar;
15086 break;
15087 }
15088 case GLSLstd450Refract:
15089 {
15090 auto &type = compiler.get<SPIRType>(args[0]);
15091 if (type.vecsize == 1)
15092 return SPVFuncImplRefractScalar;
15093 break;
15094 }
15095 case GLSLstd450FaceForward:
15096 {
15097 auto &type = compiler.get<SPIRType>(args[0]);
15098 if (type.vecsize == 1)
15099 return SPVFuncImplFaceForwardScalar;
15100 break;
15101 }
15102 case GLSLstd450MatrixInverse:
15103 {
15104 auto &mat_type = compiler.get<SPIRType>(args[0]);
15105 switch (mat_type.columns)
15106 {
15107 case 2:
15108 return SPVFuncImplInverse2x2;
15109 case 3:
15110 return SPVFuncImplInverse3x3;
15111 case 4:
15112 return SPVFuncImplInverse4x4;
15113 default:
15114 break;
15115 }
15116 break;
15117 }
15118 default:
15119 break;
15120 }
15121 }
15122 break;
15123 }
15124
15125 case OpGroupNonUniformBroadcast:
15126 return SPVFuncImplSubgroupBroadcast;
15127
15128 case OpGroupNonUniformBroadcastFirst:
15129 return SPVFuncImplSubgroupBroadcastFirst;
15130
15131 case OpGroupNonUniformBallot:
15132 return SPVFuncImplSubgroupBallot;
15133
15134 case OpGroupNonUniformInverseBallot:
15135 case OpGroupNonUniformBallotBitExtract:
15136 return SPVFuncImplSubgroupBallotBitExtract;
15137
15138 case OpGroupNonUniformBallotFindLSB:
15139 return SPVFuncImplSubgroupBallotFindLSB;
15140
15141 case OpGroupNonUniformBallotFindMSB:
15142 return SPVFuncImplSubgroupBallotFindMSB;
15143
15144 case OpGroupNonUniformBallotBitCount:
15145 return SPVFuncImplSubgroupBallotBitCount;
15146
15147 case OpGroupNonUniformAllEqual:
15148 return SPVFuncImplSubgroupAllEqual;
15149
15150 case OpGroupNonUniformShuffle:
15151 return SPVFuncImplSubgroupShuffle;
15152
15153 case OpGroupNonUniformShuffleXor:
15154 return SPVFuncImplSubgroupShuffleXor;
15155
15156 case OpGroupNonUniformShuffleUp:
15157 return SPVFuncImplSubgroupShuffleUp;
15158
15159 case OpGroupNonUniformShuffleDown:
15160 return SPVFuncImplSubgroupShuffleDown;
15161
15162 case OpGroupNonUniformQuadBroadcast:
15163 return SPVFuncImplQuadBroadcast;
15164
15165 case OpGroupNonUniformQuadSwap:
15166 return SPVFuncImplQuadSwap;
15167
15168 default:
15169 break;
15170 }
15171 return SPVFuncImplNone;
15172 }
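
// As an illustration of what an SPVFuncImpl expands to later, the helper
// emitted for SPVFuncImplRadians looks roughly like this (a sketch, not the
// verbatim emission):
//
//     // Implementation of the GLSL radians() function
//     template<typename T>
//     inline T radians(T d)
//     {
//         return d * T(0.01745329251);
//     }
//
// Each enum value maps to one such Metal-side helper, emitted at most once.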
15173
15174 // Sort both type and meta member content based on builtin status (put builtins at end),
15175 // then by the required sorting aspect.
15176 void CompilerMSL::MemberSorter::sort()
15177 {
15178 // Create a temporary array of consecutive member indices and sort it based on how
15179 // the members should be reordered, based on builtin and sorting aspect meta info.
15180 size_t mbr_cnt = type.member_types.size();
15181 SmallVector<uint32_t> mbr_idxs(mbr_cnt);
15182 std::iota(mbr_idxs.begin(), mbr_idxs.end(), 0); // Fill with consecutive indices
15183 std::stable_sort(mbr_idxs.begin(), mbr_idxs.end(), *this); // Sort member indices based on sorting aspect
15184
15185 bool sort_is_identity = true;
15186 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
15187 {
15188 if (mbr_idx != mbr_idxs[mbr_idx])
15189 {
15190 sort_is_identity = false;
15191 break;
15192 }
15193 }
15194
15195 if (sort_is_identity)
15196 return;
15197
15198 if (meta.members.size() < type.member_types.size())
15199 {
15200 // This should never trigger in normal circumstances, but to be safe.
15201 meta.members.resize(type.member_types.size());
15202 }
15203
15204 // Move type and meta member info to the order defined by the sorted member indices.
15205 // This is done by creating temporary copies of both member types and meta, and then
15206 // copying back to the original content at the sorted indices.
15207 auto mbr_types_cpy = type.member_types;
15208 auto mbr_meta_cpy = meta.members;
15209 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
15210 {
15211 type.member_types[mbr_idx] = mbr_types_cpy[mbr_idxs[mbr_idx]];
15212 meta.members[mbr_idx] = mbr_meta_cpy[mbr_idxs[mbr_idx]];
15213 }
15214
15215 if (sort_aspect == SortAspect::Offset)
15216 {
15217 // If we're sorting by Offset, this might affect user code which accesses a buffer block.
15218 // We will need to redirect member indices from one index to sorted index.
15219 type.member_type_index_redirection = std::move(mbr_idxs);
15220 }
15221 }
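
// The index-permutation idiom above is the standard way to reorder parallel
// arrays in one pass. A standalone sketch of the same technique:
//
//     SmallVector<uint32_t> idxs(n);
//     std::iota(idxs.begin(), idxs.end(), 0);          // 0, 1, ..., n - 1
//     std::stable_sort(idxs.begin(), idxs.end(), cmp); // permute indices only
//     // then fill position i of each array from old position idxs[i]
//
// stable_sort is required so that members comparing equal keep their original
// relative order.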
15222
15223 bool CompilerMSL::MemberSorter::operator()(uint32_t mbr_idx1, uint32_t mbr_idx2)
15224 {
15225 auto &mbr_meta1 = meta.members[mbr_idx1];
15226 auto &mbr_meta2 = meta.members[mbr_idx2];
15227
15228 if (sort_aspect == LocationThenBuiltInType)
15229 {
15230 // Sort first by builtin status (put builtins at end), then by the sorting aspect.
15231 if (mbr_meta1.builtin != mbr_meta2.builtin)
15232 return mbr_meta2.builtin;
15233 else if (mbr_meta1.builtin)
15234 return mbr_meta1.builtin_type < mbr_meta2.builtin_type;
15235 else if (mbr_meta1.location == mbr_meta2.location)
15236 return mbr_meta1.component < mbr_meta2.component;
15237 else
15238 return mbr_meta1.location < mbr_meta2.location;
15239 }
15240 else
15241 return mbr_meta1.offset < mbr_meta2.offset;
15242 }
15243
15244 CompilerMSL::MemberSorter::MemberSorter(SPIRType &t, Meta &m, SortAspect sa)
15245 : type(t)
15246 , meta(m)
15247 , sort_aspect(sa)
15248 {
15249 // Ensure enough meta info is available
15250 meta.members.resize(max(type.member_types.size(), meta.members.size()));
15251 }
15252
15253 void CompilerMSL::remap_constexpr_sampler(VariableID id, const MSLConstexprSampler &sampler)
15254 {
15255 auto &type = get<SPIRType>(get<SPIRVariable>(id).basetype);
15256 if (type.basetype != SPIRType::SampledImage && type.basetype != SPIRType::Sampler)
15257 SPIRV_CROSS_THROW("Can only remap SampledImage and Sampler type.");
15258 if (!type.array.empty())
15259 SPIRV_CROSS_THROW("Can not remap array of samplers.");
15260 constexpr_samplers_by_id[id] = sampler;
15261 }
15262
15263 void CompilerMSL::remap_constexpr_sampler_by_binding(uint32_t desc_set, uint32_t binding,
15264 const MSLConstexprSampler &sampler)
15265 {
15266 constexpr_samplers_by_binding[{ desc_set, binding }] = sampler;
15267 }
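
// Hypothetical usage sketch from the API side (the sampler is left at its
// MSLConstexprSampler defaults here; configure its fields as needed):
//
//     CompilerMSL msl(std::move(spirv_words));
//     MSLConstexprSampler samp;
//     msl.remap_constexpr_sampler_by_binding(0, 1, samp); // set 0, binding 1
//
// The by-ID variant above behaves identically but keys on the variable ID.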
15268
15269 void CompilerMSL::cast_from_builtin_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
15270 {
15271 auto *var = maybe_get_backing_variable(source_id);
15272 if (var)
15273 source_id = var->self;
15274
15275 // Only interested in standalone builtin variables.
15276 if (!has_decoration(source_id, DecorationBuiltIn))
15277 return;
15278
15279 auto builtin = static_cast<BuiltIn>(get_decoration(source_id, DecorationBuiltIn));
15280 auto expected_type = expr_type.basetype;
15281 auto expected_width = expr_type.width;
15282 switch (builtin)
15283 {
15284 case BuiltInGlobalInvocationId:
15285 case BuiltInLocalInvocationId:
15286 case BuiltInWorkgroupId:
15287 case BuiltInLocalInvocationIndex:
15288 case BuiltInWorkgroupSize:
15289 case BuiltInNumWorkgroups:
15290 case BuiltInLayer:
15291 case BuiltInViewportIndex:
15292 case BuiltInFragStencilRefEXT:
15293 case BuiltInPrimitiveId:
15294 case BuiltInSubgroupSize:
15295 case BuiltInSubgroupLocalInvocationId:
15296 case BuiltInViewIndex:
15297 case BuiltInVertexIndex:
15298 case BuiltInInstanceIndex:
15299 case BuiltInBaseInstance:
15300 case BuiltInBaseVertex:
15301 expected_type = SPIRType::UInt;
15302 expected_width = 32;
15303 break;
15304
15305 case BuiltInTessLevelInner:
15306 case BuiltInTessLevelOuter:
15307 if (get_execution_model() == ExecutionModelTessellationControl)
15308 {
15309 expected_type = SPIRType::Half;
15310 expected_width = 16;
15311 }
15312 break;
15313
15314 default:
15315 break;
15316 }
15317
15318 if (expected_type != expr_type.basetype)
15319 {
15320 if (!expr_type.array.empty() && (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter))
15321 {
15322 // Triggers when loading TessLevel directly as an array.
15323 // Need explicit padding + cast.
15324 auto wrap_expr = join(type_to_glsl(expr_type), "({ ");
15325
15326 uint32_t array_size = get_physical_tess_level_array_size(builtin);
15327 for (uint32_t i = 0; i < array_size; i++)
15328 {
15329 if (array_size > 1)
15330 wrap_expr += join("float(", expr, "[", i, "])");
15331 else
15332 wrap_expr += join("float(", expr, ")");
15333 if (i + 1 < array_size)
15334 wrap_expr += ", ";
15335 }
15336
15337 if (get_execution_mode_bitset().get(ExecutionModeTriangles))
15338 wrap_expr += ", 0.0";
15339
15340 wrap_expr += " })";
15341 expr = std::move(wrap_expr);
15342 }
15343 else
15344 {
15345 // These are of different widths, so we cannot do a straight bitcast.
15346 if (expected_width != expr_type.width)
15347 expr = join(type_to_glsl(expr_type), "(", expr, ")");
15348 else
15349 expr = bitcast_expression(expr_type, expected_type, expr);
15350 }
15351 }
15352
15353 if (builtin == BuiltInTessCoord && get_entry_point().flags.get(ExecutionModeQuads) && expr_type.vecsize == 3)
15354 {
15355 // In SPIR-V, this is always a vec3, even for quads. In Metal, though, it's a float2 for quads.
15356 // The code is expecting a float3, so we need to widen this.
15357 expr = join("float3(", expr, ", 0)");
15358 }
15359 }
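
// Illustrative shape of the TessLevel padding wrap above, for an array load:
//
//     <array type>({ float(expr[0]), float(expr[1]), /* ... */, 0.0 })
//
// where the trailing 0.0 is appended only under ExecutionModeTriangles and the
// element count comes from get_physical_tess_level_array_size().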
15360
15361 void CompilerMSL::cast_to_builtin_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
15362 {
15363 auto *var = maybe_get_backing_variable(target_id);
15364 if (var)
15365 target_id = var->self;
15366
15367 // Only interested in standalone builtin variables.
15368 if (!has_decoration(target_id, DecorationBuiltIn))
15369 return;
15370
15371 auto builtin = static_cast<BuiltIn>(get_decoration(target_id, DecorationBuiltIn));
15372 auto expected_type = expr_type.basetype;
15373 auto expected_width = expr_type.width;
15374 switch (builtin)
15375 {
15376 case BuiltInLayer:
15377 case BuiltInViewportIndex:
15378 case BuiltInFragStencilRefEXT:
15379 case BuiltInPrimitiveId:
15380 case BuiltInViewIndex:
15381 expected_type = SPIRType::UInt;
15382 expected_width = 32;
15383 break;
15384
15385 case BuiltInTessLevelInner:
15386 case BuiltInTessLevelOuter:
15387 expected_type = SPIRType::Half;
15388 expected_width = 16;
15389 break;
15390
15391 default:
15392 break;
15393 }
15394
15395 if (expected_type != expr_type.basetype)
15396 {
15397 if (expected_width != expr_type.width)
15398 {
15399 // These are of different widths, so we cannot do a straight bitcast.
15400 auto type = expr_type;
15401 type.basetype = expected_type;
15402 type.width = expected_width;
15403 expr = join(type_to_glsl(type), "(", expr, ")");
15404 }
15405 else
15406 {
15407 auto type = expr_type;
15408 type.basetype = expected_type;
15409 expr = bitcast_expression(type, expr_type.basetype, expr);
15410 }
15411 }
15412 }
15413
15414 string CompilerMSL::to_initializer_expression(const SPIRVariable &var)
15415 {
15416 // With MSL, we risk emitting an array initializer here if the variable is an array.
15417 // FIXME: We cannot handle non-constant arrays being initialized.
15418 // We will need to inject spvArrayCopy here somehow ...
15419 auto &type = get<SPIRType>(var.basetype);
15420 string expr;
15421 if (ir.ids[var.initializer].get_type() == TypeConstant &&
15422 (!type.array.empty() || type.basetype == SPIRType::Struct))
15423 expr = constant_expression(get<SPIRConstant>(var.initializer));
15424 else
15425 expr = CompilerGLSL::to_initializer_expression(var);
15426 // If the initializer has more vector components than the variable, add a swizzle.
15427 // FIXME: This can't handle arrays or structs.
15428 auto &init_type = expression_type(var.initializer);
15429 if (type.array.empty() && type.basetype != SPIRType::Struct && init_type.vecsize > type.vecsize)
15430 expr = enclose_expression(expr + vector_swizzle(type.vecsize, 0));
15431 return expr;
15432 }
15433
15434 string CompilerMSL::to_zero_initialized_expression(uint32_t)
15435 {
15436 return "{}";
15437 }
15438
15439 bool CompilerMSL::descriptor_set_is_argument_buffer(uint32_t desc_set) const
15440 {
15441 if (!msl_options.argument_buffers)
15442 return false;
15443 if (desc_set >= kMaxArgumentBuffers)
15444 return false;
15445
15446 return (argument_buffer_discrete_mask & (1u << desc_set)) == 0;
15447 }
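
// Example: with argument_buffers enabled and argument_buffer_discrete_mask ==
// 0x4, descriptor set 2 stays a discrete set, while sets 0, 1, 3, ... (up to
// kMaxArgumentBuffers) are emitted as argument buffers.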
15448
15449 bool CompilerMSL::is_supported_argument_buffer_type(const SPIRType &type) const
15450 {
15451 // Very specifically, image load-store in argument buffers is disallowed in MSL on iOS.
15452 // But we won't know when the argument buffer is encoded whether this image will have
15453 // a NonWritable decoration. So just use discrete arguments for all storage images
15454 // on iOS.
15455 bool is_storage_image = type.basetype == SPIRType::Image && type.image.sampled == 2;
15456 bool is_supported_type = !msl_options.is_ios() || !is_storage_image;
15457 return !type_is_msl_framebuffer_fetch(type) && is_supported_type;
15458 }
15459
15460 void CompilerMSL::analyze_argument_buffers()
15461 {
15462 // Gather all used resources and sort them out into argument buffers.
15463 // Each argument buffer corresponds to a descriptor set in SPIR-V.
15464 // The [[id(N)]] values used correspond to the resource mapping we have for MSL.
15465 // Otherwise, the binding number is used, but this is generally not safe for some types,
15466 // such as combined image samplers and arrays of resources. Metal needs different indices here,
15467 // while SPIR-V can have one descriptor set binding. To use argument buffers in practice,
15468 // you will need to use the remapping from the API.
15469 for (auto &id : argument_buffer_ids)
15470 id = 0;
15471
15472 // Output resources, sorted by resource index & type.
15473 struct Resource
15474 {
15475 SPIRVariable *var;
15476 string name;
15477 SPIRType::BaseType basetype;
15478 uint32_t index;
15479 uint32_t plane;
15480 };
15481 SmallVector<Resource> resources_in_set[kMaxArgumentBuffers];
15482 SmallVector<uint32_t> inline_block_vars;
15483
15484 bool set_needs_swizzle_buffer[kMaxArgumentBuffers] = {};
15485 bool set_needs_buffer_sizes[kMaxArgumentBuffers] = {};
15486 bool needs_buffer_sizes = false;
15487
15488 ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &var) {
15489 if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
15490 var.storage == StorageClassStorageBuffer) &&
15491 !is_hidden_variable(var))
15492 {
15493 uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
15494 // Ignore if it's part of a push descriptor set.
15495 if (!descriptor_set_is_argument_buffer(desc_set))
15496 return;
15497
15498 uint32_t var_id = var.self;
15499 auto &type = get_variable_data_type(var);
15500
15501 if (desc_set >= kMaxArgumentBuffers)
15502 SPIRV_CROSS_THROW("Descriptor set index is out of range.");
15503
15504 const MSLConstexprSampler *constexpr_sampler = nullptr;
15505 if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
15506 {
15507 constexpr_sampler = find_constexpr_sampler(var_id);
15508 if (constexpr_sampler)
15509 {
15510 // Mark this ID as a constexpr sampler for later in case it came from set/bindings.
15511 constexpr_samplers_by_id[var_id] = *constexpr_sampler;
15512 }
15513 }
15514
15515 uint32_t binding = get_decoration(var_id, DecorationBinding);
15516 if (type.basetype == SPIRType::SampledImage)
15517 {
15518 add_resource_name(var_id);
15519
15520 uint32_t plane_count = 1;
15521 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
15522 plane_count = constexpr_sampler->planes;
15523
15524 for (uint32_t i = 0; i < plane_count; i++)
15525 {
15526 uint32_t image_resource_index = get_metal_resource_index(var, SPIRType::Image, i);
15527 resources_in_set[desc_set].push_back(
15528 { &var, to_name(var_id), SPIRType::Image, image_resource_index, i });
15529 }
15530
15531 if (type.image.dim != DimBuffer && !constexpr_sampler)
15532 {
15533 uint32_t sampler_resource_index = get_metal_resource_index(var, SPIRType::Sampler);
15534 resources_in_set[desc_set].push_back(
15535 { &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index, 0 });
15536 }
15537 }
15538 else if (inline_uniform_blocks.count(SetBindingPair{ desc_set, binding }))
15539 {
15540 inline_block_vars.push_back(var_id);
15541 }
15542 else if (!constexpr_sampler && is_supported_argument_buffer_type(type))
15543 {
15544 // constexpr samplers are not declared as resources.
15545 // Inline uniform blocks are always emitted at the end.
15546 add_resource_name(var_id);
15547 resources_in_set[desc_set].push_back(
15548 { &var, to_name(var_id), type.basetype, get_metal_resource_index(var, type.basetype), 0 });
15549
15550 // Emulate texture2D atomic operations
15551 if (atomic_image_vars.count(var.self))
15552 {
15553 uint32_t buffer_resource_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0);
15554 resources_in_set[desc_set].push_back(
15555 { &var, to_name(var_id) + "_atomic", SPIRType::Struct, buffer_resource_index, 0 });
15556 }
15557 }
15558
15559 // Check if this descriptor set needs a swizzle buffer.
15560 if (needs_swizzle_buffer_def && is_sampled_image_type(type))
15561 set_needs_swizzle_buffer[desc_set] = true;
15562 else if (buffers_requiring_array_length.count(var_id) != 0)
15563 {
15564 set_needs_buffer_sizes[desc_set] = true;
15565 needs_buffer_sizes = true;
15566 }
15567 }
15568 });
15569
15570 if (needs_swizzle_buffer_def || needs_buffer_sizes)
15571 {
15572 uint32_t uint_ptr_type_id = 0;
15573
15574 // We might have to add a swizzle buffer resource to the set.
15575 for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
15576 {
15577 if (!set_needs_swizzle_buffer[desc_set] && !set_needs_buffer_sizes[desc_set])
15578 continue;
15579
15580 if (uint_ptr_type_id == 0)
15581 {
15582 uint_ptr_type_id = ir.increase_bound_by(1);
15583
15584 // Create a buffer to hold extra data, including the swizzle constants.
15585 SPIRType uint_type_pointer = get_uint_type();
15586 uint_type_pointer.pointer = true;
15587 uint_type_pointer.pointer_depth++;
15588 uint_type_pointer.parent_type = get_uint_type_id();
15589 uint_type_pointer.storage = StorageClassUniform;
15590 set<SPIRType>(uint_ptr_type_id, uint_type_pointer);
15591 set_decoration(uint_ptr_type_id, DecorationArrayStride, 4);
15592 }
15593
15594 if (set_needs_swizzle_buffer[desc_set])
15595 {
15596 uint32_t var_id = ir.increase_bound_by(1);
15597 auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
15598 set_name(var_id, "spvSwizzleConstants");
15599 set_decoration(var_id, DecorationDescriptorSet, desc_set);
15600 set_decoration(var_id, DecorationBinding, kSwizzleBufferBinding);
15601 resources_in_set[desc_set].push_back(
15602 { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0 });
15603 }
15604
15605 if (set_needs_buffer_sizes[desc_set])
15606 {
15607 uint32_t var_id = ir.increase_bound_by(1);
15608 auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
15609 set_name(var_id, "spvBufferSizeConstants");
15610 set_decoration(var_id, DecorationDescriptorSet, desc_set);
15611 set_decoration(var_id, DecorationBinding, kBufferSizeBufferBinding);
15612 resources_in_set[desc_set].push_back(
15613 { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0 });
15614 }
15615 }
15616 }
15617
15618 // Now add inline uniform blocks.
15619 for (uint32_t var_id : inline_block_vars)
15620 {
15621 auto &var = get<SPIRVariable>(var_id);
15622 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
15623 add_resource_name(var_id);
15624 resources_in_set[desc_set].push_back(
15625 { &var, to_name(var_id), SPIRType::Struct, get_metal_resource_index(var, SPIRType::Struct), 0 });
15626 }
15627
15628 for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
15629 {
15630 auto &resources = resources_in_set[desc_set];
15631 if (resources.empty())
15632 continue;
15633
15634 assert(descriptor_set_is_argument_buffer(desc_set));
15635
15636 uint32_t next_id = ir.increase_bound_by(3);
15637 uint32_t type_id = next_id + 1;
15638 uint32_t ptr_type_id = next_id + 2;
15639 argument_buffer_ids[desc_set] = next_id;
15640
15641 auto &buffer_type = set<SPIRType>(type_id);
15642
15643 buffer_type.basetype = SPIRType::Struct;
15644
15645 if ((argument_buffer_device_storage_mask & (1u << desc_set)) != 0)
15646 {
15647 buffer_type.storage = StorageClassStorageBuffer;
15648 // Make sure the argument buffer gets marked as const device.
15649 set_decoration(next_id, DecorationNonWritable);
15650 // Need to mark the type as a Block to enable this.
15651 set_decoration(type_id, DecorationBlock);
15652 }
15653 else
15654 buffer_type.storage = StorageClassUniform;
15655
15656 set_name(type_id, join("spvDescriptorSetBuffer", desc_set));
15657
15658 auto &ptr_type = set<SPIRType>(ptr_type_id);
15659 ptr_type = buffer_type;
15660 ptr_type.pointer = true;
15661 ptr_type.pointer_depth++;
15662 ptr_type.parent_type = type_id;
15663
15664 uint32_t buffer_variable_id = next_id;
15665 set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
15666 set_name(buffer_variable_id, join("spvDescriptorSet", desc_set));
15667
15668 // Ids must be emitted in ID order.
15669 sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool {
15670 return tie(lhs.index, lhs.basetype) < tie(rhs.index, rhs.basetype);
15671 });
15672
15673 uint32_t member_index = 0;
15674 uint32_t next_arg_buff_index = 0;
15675 for (auto &resource : resources)
15676 {
15677 auto &var = *resource.var;
15678 auto &type = get_variable_data_type(var);
15679
15680 // If needed, synthesize and add padding members.
15681 // member_index and next_arg_buff_index are incremented when padding members are added.
15682 if (msl_options.pad_argument_buffer_resources)
15683 {
15684 while (resource.index > next_arg_buff_index)
15685 {
15686 auto &rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index);
15687 switch (rez_bind.basetype)
15688 {
15689 case SPIRType::Void:
15690 case SPIRType::Boolean:
15691 case SPIRType::SByte:
15692 case SPIRType::UByte:
15693 case SPIRType::Short:
15694 case SPIRType::UShort:
15695 case SPIRType::Int:
15696 case SPIRType::UInt:
15697 case SPIRType::Int64:
15698 case SPIRType::UInt64:
15699 case SPIRType::AtomicCounter:
15700 case SPIRType::Half:
15701 case SPIRType::Float:
15702 case SPIRType::Double:
15703 add_argument_buffer_padding_buffer_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
15704 break;
15705 case SPIRType::Image:
15706 add_argument_buffer_padding_image_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
15707 break;
15708 case SPIRType::Sampler:
15709 add_argument_buffer_padding_sampler_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
15710 break;
15711 case SPIRType::SampledImage:
15712 if (next_arg_buff_index == rez_bind.msl_sampler)
15713 add_argument_buffer_padding_sampler_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
15714 else
15715 add_argument_buffer_padding_image_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
15716 break;
15717 default:
15718 break;
15719 }
15720 }
15721
15722 // Adjust the number of slots consumed by the current member itself.
15723 // If actual member is an array, allow runtime array resolution as well.
15724 uint32_t elem_cnt = type.array.empty() ? 1 : to_array_size_literal(type);
15725 if (elem_cnt == 0)
15726 elem_cnt = get_resource_array_size(var.self);
15727
15728 next_arg_buff_index += elem_cnt;
15729 }
15730
15731 string mbr_name = ensure_valid_name(resource.name, "m");
15732 if (resource.plane > 0)
15733 mbr_name += join(plane_name_suffix, resource.plane);
15734 set_member_name(buffer_type.self, member_index, mbr_name);
15735
15736 if (resource.basetype == SPIRType::Sampler && type.basetype != SPIRType::Sampler)
15737 {
15738 // Have to synthesize a sampler type here.
15739
15740 bool type_is_array = !type.array.empty();
15741 uint32_t sampler_type_id = ir.increase_bound_by(type_is_array ? 2 : 1);
15742 auto &new_sampler_type = set<SPIRType>(sampler_type_id);
15743 new_sampler_type.basetype = SPIRType::Sampler;
15744 new_sampler_type.storage = StorageClassUniformConstant;
15745
15746 if (type_is_array)
15747 {
15748 uint32_t sampler_type_array_id = sampler_type_id + 1;
15749 auto &sampler_type_array = set<SPIRType>(sampler_type_array_id);
15750 sampler_type_array = new_sampler_type;
15751 sampler_type_array.array = type.array;
15752 sampler_type_array.array_size_literal = type.array_size_literal;
15753 sampler_type_array.parent_type = sampler_type_id;
15754 buffer_type.member_types.push_back(sampler_type_array_id);
15755 }
15756 else
15757 buffer_type.member_types.push_back(sampler_type_id);
15758 }
15759 else
15760 {
15761 uint32_t binding = get_decoration(var.self, DecorationBinding);
15762 SetBindingPair pair = { desc_set, binding };
15763
15764 if (resource.basetype == SPIRType::Image || resource.basetype == SPIRType::Sampler ||
15765 resource.basetype == SPIRType::SampledImage)
15766 {
15767 // Drop pointer information when we emit the resources into a struct.
15768 buffer_type.member_types.push_back(get_variable_data_type_id(var));
15769 if (resource.plane == 0)
15770 set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
15771 }
15772 else if (buffers_requiring_dynamic_offset.count(pair))
15773 {
15774 // Don't set the qualified name here; we'll define a variable holding the corrected buffer address later.
15775 buffer_type.member_types.push_back(var.basetype);
15776 buffers_requiring_dynamic_offset[pair].second = var.self;
15777 }
15778 else if (inline_uniform_blocks.count(pair))
15779 {
15780 // Put the buffer block itself into the argument buffer.
15781 buffer_type.member_types.push_back(get_variable_data_type_id(var));
15782 set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
15783 }
15784 else if (atomic_image_vars.count(var.self))
15785 {
15786 // Emulate texture2D atomic operations.
15787 // Don't set the qualified name: it's already set for this variable,
15788 // and the code that references the buffer manually appends "_atomic"
15789 // to the name.
15790 uint32_t offset = ir.increase_bound_by(2);
15791 uint32_t atomic_type_id = offset;
15792 uint32_t type_ptr_id = offset + 1;
15793
15794 SPIRType atomic_type;
15795 atomic_type.basetype = SPIRType::AtomicCounter;
15796 atomic_type.width = 32;
15797 atomic_type.vecsize = 1;
15798 set<SPIRType>(atomic_type_id, atomic_type);
15799
15800 atomic_type.pointer = true;
15801 atomic_type.pointer_depth++;
15802 atomic_type.parent_type = atomic_type_id;
15803 atomic_type.storage = StorageClassStorageBuffer;
15804 auto &atomic_ptr_type = set<SPIRType>(type_ptr_id, atomic_type);
15805 atomic_ptr_type.self = atomic_type_id;
15806
15807 buffer_type.member_types.push_back(type_ptr_id);
15808 }
15809 else
15810 {
15811 // Resources will be declared as pointers not references, so automatically dereference as appropriate.
15812 buffer_type.member_types.push_back(var.basetype);
15813 if (type.array.empty())
15814 set_qualified_name(var.self, join("(*", to_name(buffer_variable_id), ".", mbr_name, ")"));
15815 else
15816 set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
15817 }
15818 }
15819
15820 set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationResourceIndexPrimary,
15821 resource.index);
15822 set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationInterfaceOrigID,
15823 var.self);
15824 member_index++;
15825 }
15826 }
15827 }
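
// The net effect of the analysis above is one synthesized struct per
// qualifying descriptor set. A rough sketch of the emitted MSL, with
// illustrative names and resource indices:
//
//     struct spvDescriptorSetBuffer0
//     {
//         texture2d<float> uTexture [[id(0)]];
//         sampler uTextureSmplr [[id(1)]];
//         constant UBO* uUBO [[id(2)]];
//     };
//
// Members are ordered by [[id(N)]], and padding members are synthesized when
// pad_argument_buffer_resources is enabled.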
15828
15829 // Returns the app-provided resource binding for the descriptor set that matches
15830 // the given resource index within the argument buffer.
15831 // This is a two-step lookup: first look up the resource binding number from the argument buffer index,
15832 // then look up the resource binding using that binding number.
15833 MSLResourceBinding &CompilerMSL::get_argument_buffer_resource(uint32_t desc_set, uint32_t arg_idx)
15834 {
15835 auto stage = get_entry_point().model;
15836 StageSetBinding arg_idx_tuple = { stage, desc_set, arg_idx };
15837 auto arg_itr = resource_arg_buff_idx_to_binding_number.find(arg_idx_tuple);
15838 if (arg_itr != end(resource_arg_buff_idx_to_binding_number))
15839 {
15840 StageSetBinding bind_tuple = { stage, desc_set, arg_itr->second };
15841 auto bind_itr = resource_bindings.find(bind_tuple);
15842 if (bind_itr != end(resource_bindings))
15843 return bind_itr->second.first;
15844 }
15845 SPIRV_CROSS_THROW("Argument buffer resource base type could not be determined. When padding argument buffer "
15846 "elements, all descriptor set resources must be supplied with a base type by the app.");
15847 }
15848
15849 // Adds a padding buffer type to the argument buffer, as one or more members of the struct type at the member index.
15850 // Metal does not support arrays of buffers, so these are emitted as multiple individual struct members.
15851 void CompilerMSL::add_argument_buffer_padding_buffer_type(SPIRType &struct_type, uint32_t &mbr_idx,
15852 uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
15853 {
15854 if (!argument_buffer_padding_buffer_type_id)
15855 {
15856 uint32_t buff_type_id = ir.increase_bound_by(2);
15857 auto &buff_type = set<SPIRType>(buff_type_id);
15858 buff_type.basetype = rez_bind.basetype;
15859 buff_type.storage = StorageClassUniformConstant;
15860
15861 uint32_t ptr_type_id = buff_type_id + 1;
15862 auto &ptr_type = set<SPIRType>(ptr_type_id);
15863 ptr_type = buff_type;
15864 ptr_type.pointer = true;
15865 ptr_type.pointer_depth++;
15866 ptr_type.parent_type = buff_type_id;
15867
15868 argument_buffer_padding_buffer_type_id = ptr_type_id;
15869 }
15870
15871 for (uint32_t rez_idx = 0; rez_idx < rez_bind.count; rez_idx++)
15872 add_argument_buffer_padding_type(argument_buffer_padding_buffer_type_id, struct_type, mbr_idx, arg_buff_index, 1);
15873 }
15874
15875 // Adds a padding image type to the argument buffer, as a member of the struct type at the member index.
15876 void CompilerMSL::add_argument_buffer_padding_image_type(SPIRType &struct_type, uint32_t &mbr_idx,
15877 uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
15878 {
15879 if (!argument_buffer_padding_image_type_id)
15880 {
15881 uint32_t base_type_id = ir.increase_bound_by(2);
15882 auto &base_type = set<SPIRType>(base_type_id);
15883 base_type.basetype = SPIRType::Float;
15884 base_type.width = 32;
15885
15886 uint32_t img_type_id = base_type_id + 1;
15887 auto &img_type = set<SPIRType>(img_type_id);
15888 img_type.basetype = SPIRType::Image;
15889 img_type.storage = StorageClassUniformConstant;
15890
15891 img_type.image.type = base_type_id;
15892 img_type.image.dim = Dim2D;
15893 img_type.image.depth = false;
15894 img_type.image.arrayed = false;
15895 img_type.image.ms = false;
15896 img_type.image.sampled = 1;
15897 img_type.image.format = ImageFormatUnknown;
15898 img_type.image.access = AccessQualifierMax;
15899
15900 argument_buffer_padding_image_type_id = img_type_id;
15901 }
15902
15903 add_argument_buffer_padding_type(argument_buffer_padding_image_type_id, struct_type, mbr_idx, arg_buff_index, rez_bind.count);
15904 }
15905
15906 // Adds a padding sampler type to the argument buffer, as a member of the struct type at the member index.
15907 void CompilerMSL::add_argument_buffer_padding_sampler_type(SPIRType &struct_type, uint32_t &mbr_idx,
15908 uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
15909 {
15910 if (!argument_buffer_padding_sampler_type_id)
15911 {
15912 uint32_t samp_type_id = ir.increase_bound_by(1);
15913 auto &samp_type = set<SPIRType>(samp_type_id);
15914 samp_type.basetype = SPIRType::Sampler;
15915 samp_type.storage = StorageClassUniformConstant;
15916
15917 argument_buffer_padding_sampler_type_id = samp_type_id;
15918 }
15919
15920 add_argument_buffer_padding_type(argument_buffer_padding_sampler_type_id, struct_type, mbr_idx, arg_buff_index, rez_bind.count);
15921 }
15922
15923 // Adds the given padding type as a member of the struct type at the member index.
15924 // Advances both arg_buff_index and mbr_idx to the next argument slots.
15925 void CompilerMSL::add_argument_buffer_padding_type(uint32_t mbr_type_id, SPIRType &struct_type, uint32_t &mbr_idx,
15926 uint32_t &arg_buff_index, uint32_t count)
15927 {
15928 uint32_t type_id = mbr_type_id;
15929 if (count > 1)
15930 {
15931 uint32_t ary_type_id = ir.increase_bound_by(1);
15932 auto &ary_type = set<SPIRType>(ary_type_id);
15933 ary_type = get<SPIRType>(type_id);
15934 ary_type.array.push_back(count);
15935 ary_type.array_size_literal.push_back(true);
15936 ary_type.parent_type = type_id;
15937 type_id = ary_type_id;
15938 }
15939
15940 set_member_name(struct_type.self, mbr_idx, join("_m", arg_buff_index, "_pad"));
15941 set_extended_member_decoration(struct_type.self, mbr_idx, SPIRVCrossDecorationResourceIndexPrimary, arg_buff_index);
15942 struct_type.member_types.push_back(type_id);
15943
15944 arg_buff_index += count;
15945 mbr_idx++;
15946 }
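
// Example: a padding member consuming argument slot 3 is named "_m3_pad"; with
// count == 4 it becomes a single array member spanning slots 3..6, after which
// arg_buff_index resumes at 7.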
15947
15948 void CompilerMSL::activate_argument_buffer_resources()
15949 {
15950 // For ABI compatibility, force-enable all resources which are part of argument buffers.
15951 ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, const SPIRVariable &) {
15952 if (!has_decoration(self, DecorationDescriptorSet))
15953 return;
15954
15955 uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
15956 if (descriptor_set_is_argument_buffer(desc_set))
15957 active_interface_variables.insert(self);
15958 });
15959 }
15960
15961 bool CompilerMSL::using_builtin_array() const
15962 {
15963 return msl_options.force_native_arrays || is_using_builtin_array;
15964 }
15965
15966 void CompilerMSL::set_combined_sampler_suffix(const char *suffix)
15967 {
15968 sampler_name_suffix = suffix;
15969 }
15970
15971 const char *CompilerMSL::get_combined_sampler_suffix() const
15972 {
15973 return sampler_name_suffix.c_str();
15974 }
15975
15976 void CompilerMSL::emit_block_hints(const SPIRBlock &)
15977 {
15978 }
15979
15980 string CompilerMSL::additional_fixed_sample_mask_str() const
15981 {
15982 char print_buffer[32];
15983 snprintf(print_buffer, sizeof(print_buffer), "0x%x", msl_options.additional_fixed_sample_mask);
15984 return print_buffer;
15985 }
15986