/*
 * Copyright © 2018 Red Hat
 *
 * Permission is hereby granted, free of charge, to any person obtaining a
 * copy of this software and associated documentation files (the "Software"),
 * to deal in the Software without restriction, including without limitation
 * the rights to use, copy, modify, merge, publish, distribute, sublicense,
 * and/or sell copies of the Software, and to permit persons to whom the
 * Software is furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice (including the next
 * paragraph) shall be included in all copies or substantial portions of the
 * Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
 * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
 * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
 * IN THE SOFTWARE.
 *
 * Authors:
 *    Rob Clark (robdclark@gmail.com)
 */

#include <math.h>
#include "nir/nir_builtin_builder.h"

#include "vtn_private.h"
#include "OpenCL.std.h"

typedef nir_ssa_def *(*nir_handler)(struct vtn_builder *b,
                                    uint32_t opcode,
                                    unsigned num_srcs, nir_ssa_def **srcs,
                                    struct vtn_type **src_types,
                                    const struct vtn_type *dest_type);

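/* Address-space numbering used by clang's SPIR target when mangling
 * address-space-qualified pointers (private = 0, global = 1, constant = 2,
 * local = 3), which is what libclc's builtins are compiled with.
 */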
static int to_llvm_address_space(SpvStorageClass mode)
{
   switch (mode) {
   case SpvStorageClassPrivate:
   case SpvStorageClassFunction: return 0;
   case SpvStorageClassCrossWorkgroup: return 1;
   case SpvStorageClassUniform:
   case SpvStorageClassUniformConstant: return 2;
   case SpvStorageClassWorkgroup: return 3;
   default: return -1;
   }
}


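/* Mangle a libclc builtin name following the Itanium C++ ABI, e.g.
 * fmin(float4, float4) becomes "_Z4fminDv4_fS_" ("S_" being the
 * substitution for the repeated float4 type), and frexp(float,
 * global int *) becomes "_Z5frexpfPU3AS1i".
 */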
static void
vtn_opencl_mangle(const char *in_name,
                  uint32_t const_mask,
                  int ntypes, struct vtn_type **src_types,
                  char **outstring)
{
   char local_name[256] = "";
   char *args_str = local_name + sprintf(local_name, "_Z%zu%s", strlen(in_name), in_name);

   for (unsigned i = 0; i < ntypes; ++i) {
      const struct glsl_type *type = src_types[i]->type;
      enum vtn_base_type base_type = src_types[i]->base_type;
      if (src_types[i]->base_type == vtn_base_type_pointer) {
         *(args_str++) = 'P';
         int address_space = to_llvm_address_space(src_types[i]->storage_class);
         if (address_space > 0)
            args_str += sprintf(args_str, "U3AS%d", address_space);

         type = src_types[i]->deref->type;
         base_type = src_types[i]->deref->base_type;
      }

      if (const_mask & (1 << i))
         *(args_str++) = 'K';

      unsigned num_elements = glsl_get_components(type);
      if (num_elements > 1) {
         /* Vectors are not treated as built-ins for mangling, so check for
          * substitution.  In theory we'd need to know which substitution
          * index this is; in practice the functions we need from libclc
          * only ever use the first one ("S_").
          */
         bool substitution = false;
         for (unsigned j = 0; j < i; ++j) {
            const struct glsl_type *other_type = src_types[j]->base_type == vtn_base_type_pointer ?
               src_types[j]->deref->type : src_types[j]->type;
            if (type == other_type) {
               substitution = true;
               break;
            }
         }

         if (substitution) {
            args_str += sprintf(args_str, "S_");
            continue;
         } else
            args_str += sprintf(args_str, "Dv%d_", num_elements);
      }

      const char *suffix = NULL;
      switch (base_type) {
      case vtn_base_type_sampler: suffix = "11ocl_sampler"; break;
      case vtn_base_type_event: suffix = "9ocl_event"; break;
      default: {
         const char *primitives[] = {
            [GLSL_TYPE_UINT] = "j",
            [GLSL_TYPE_INT] = "i",
            [GLSL_TYPE_FLOAT] = "f",
            [GLSL_TYPE_FLOAT16] = "Dh",
            [GLSL_TYPE_DOUBLE] = "d",
            [GLSL_TYPE_UINT8] = "h",
            [GLSL_TYPE_INT8] = "c",
            [GLSL_TYPE_UINT16] = "t",
            [GLSL_TYPE_INT16] = "s",
            [GLSL_TYPE_UINT64] = "m",
            [GLSL_TYPE_INT64] = "l",
            [GLSL_TYPE_BOOL] = "b",
            [GLSL_TYPE_ERROR] = NULL,
         };
         enum glsl_base_type glsl_base_type = glsl_get_base_type(type);
         assert(glsl_base_type < ARRAY_SIZE(primitives) && primitives[glsl_base_type]);
         suffix = primitives[glsl_base_type];
         break;
      }
      }
      args_str += sprintf(args_str, "%s", suffix);
   }

   *outstring = strdup(local_name);
}

static nir_function *mangle_and_find(struct vtn_builder *b,
                                     const char *name,
                                     uint32_t const_mask,
                                     uint32_t num_srcs,
                                     struct vtn_type **src_types)
{
   char *mname;
   nir_function *found = NULL;

   vtn_opencl_mangle(name, const_mask, num_srcs, src_types, &mname);
   /* Try to find the function in the current shader first. */
   nir_foreach_function(funcs, b->shader) {
      if (!strcmp(funcs->name, mname)) {
         found = funcs;
         break;
      }
   }
   /* If not found there, look in the clc shader and create a declaration
    * mirroring it.
    */
   if (!found && b->options->clc_shader && b->options->clc_shader != b->shader) {
      nir_foreach_function(funcs, b->options->clc_shader) {
         if (!strcmp(funcs->name, mname)) {
            found = funcs;
            break;
         }
      }
      if (found) {
         nir_function *decl = nir_function_create(b->shader, mname);
         decl->num_params = found->num_params;
         decl->params = ralloc_array(b->shader, nir_parameter, decl->num_params);
         for (unsigned i = 0; i < decl->num_params; i++) {
            decl->params[i] = found->params[i];
         }
         found = decl;
      }
   }
   if (!found)
      vtn_fail("Can't find clc function %s\n", mname);
   free(mname);
   return found;
}

static bool call_mangled_function(struct vtn_builder *b,
                                  const char *name,
                                  uint32_t const_mask,
                                  uint32_t num_srcs,
                                  struct vtn_type **src_types,
                                  const struct vtn_type *dest_type,
                                  nir_ssa_def **srcs,
                                  nir_deref_instr **ret_deref_ptr)
{
   nir_function *found = mangle_and_find(b, name, const_mask, num_srcs, src_types);
   if (!found)
      return false;

   nir_call_instr *call = nir_call_instr_create(b->shader, found);

   nir_deref_instr *ret_deref = NULL;
   uint32_t param_idx = 0;
   if (dest_type) {
      /* The return value is passed as a deref to a stack temporary in
       * param 0; the actual sources follow.
       */
      nir_variable *ret_tmp = nir_local_variable_create(b->nb.impl,
                                                        glsl_get_bare_type(dest_type->type),
                                                        "return_tmp");
      ret_deref = nir_build_deref_var(&b->nb, ret_tmp);
      call->params[param_idx++] = nir_src_for_ssa(&ret_deref->dest.ssa);
   }

   for (unsigned i = 0; i < num_srcs; i++)
      call->params[param_idx++] = nir_src_for_ssa(srcs[i]);
   nir_builder_instr_insert(&b->nb, &call->instr);

   *ret_deref_ptr = ret_deref;
   return true;
}

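/* For OpExtInst, w[1] is the result type id, w[2] the result id, w[3] the
 * extended instruction set id, w[4] the opcode within that set, and w[5]
 * onward the operand ids; hence the "w + 5, count - 5, w + 1" passed by
 * the callers below.
 */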
static void
handle_instr(struct vtn_builder *b, uint32_t opcode,
             const uint32_t *w_src, unsigned num_srcs, const uint32_t *w_dest, nir_handler handler)
{
   struct vtn_type *dest_type = w_dest ? vtn_get_type(b, w_dest[0]) : NULL;

   nir_ssa_def *srcs[5] = { NULL };
   struct vtn_type *src_types[5] = { NULL };
   vtn_assert(num_srcs <= ARRAY_SIZE(srcs));
   for (unsigned i = 0; i < num_srcs; i++) {
      struct vtn_value *val = vtn_untyped_value(b, w_src[i]);
      struct vtn_ssa_value *ssa = vtn_ssa_value(b, w_src[i]);
      srcs[i] = ssa->def;
      src_types[i] = val->type;
   }

   nir_ssa_def *result = handler(b, opcode, num_srcs, srcs, src_types, dest_type);
   if (result) {
      vtn_push_nir_ssa(b, w_dest[1], result);
   } else {
      vtn_assert(dest_type == NULL);
   }
}

static nir_op
nir_alu_op_for_opencl_opcode(struct vtn_builder *b,
                             enum OpenCLstd_Entrypoints opcode)
{
   switch (opcode) {
   case OpenCLstd_Fabs: return nir_op_fabs;
   case OpenCLstd_SAbs: return nir_op_iabs;
   case OpenCLstd_SAdd_sat: return nir_op_iadd_sat;
   case OpenCLstd_UAdd_sat: return nir_op_uadd_sat;
   case OpenCLstd_Ceil: return nir_op_fceil;
   case OpenCLstd_Floor: return nir_op_ffloor;
   case OpenCLstd_SHadd: return nir_op_ihadd;
   case OpenCLstd_UHadd: return nir_op_uhadd;
   case OpenCLstd_Fmax: return nir_op_fmax;
   case OpenCLstd_SMax: return nir_op_imax;
   case OpenCLstd_UMax: return nir_op_umax;
   case OpenCLstd_Fmin: return nir_op_fmin;
   case OpenCLstd_SMin: return nir_op_imin;
   case OpenCLstd_UMin: return nir_op_umin;
   case OpenCLstd_Mix: return nir_op_flrp;
   case OpenCLstd_Native_cos: return nir_op_fcos;
   case OpenCLstd_Native_divide: return nir_op_fdiv;
   case OpenCLstd_Native_exp2: return nir_op_fexp2;
   case OpenCLstd_Native_log2: return nir_op_flog2;
   case OpenCLstd_Native_powr: return nir_op_fpow;
   case OpenCLstd_Native_recip: return nir_op_frcp;
   case OpenCLstd_Native_rsqrt: return nir_op_frsq;
   case OpenCLstd_Native_sin: return nir_op_fsin;
   case OpenCLstd_Native_sqrt: return nir_op_fsqrt;
   case OpenCLstd_SMul_hi: return nir_op_imul_high;
   case OpenCLstd_UMul_hi: return nir_op_umul_high;
   case OpenCLstd_Popcount: return nir_op_bit_count;
   case OpenCLstd_SRhadd: return nir_op_irhadd;
   case OpenCLstd_URhadd: return nir_op_urhadd;
   case OpenCLstd_Rsqrt: return nir_op_frsq;
   case OpenCLstd_Sign: return nir_op_fsign;
   case OpenCLstd_Sqrt: return nir_op_fsqrt;
   case OpenCLstd_SSub_sat: return nir_op_isub_sat;
   case OpenCLstd_USub_sat: return nir_op_usub_sat;
   case OpenCLstd_Trunc: return nir_op_ftrunc;
   case OpenCLstd_Rint: return nir_op_fround_even;
   case OpenCLstd_Half_divide: return nir_op_fdiv;
   case OpenCLstd_Half_recip: return nir_op_frcp;
   /* Abs of an unsigned value is a no-op, so just emit a mov. */
   case OpenCLstd_UAbs: return nir_op_mov;
   default:
      vtn_fail("No NIR equivalent");
   }
}

static nir_ssa_def *
handle_alu(struct vtn_builder *b, uint32_t opcode,
           unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
           const struct vtn_type *dest_type)
{
   nir_ssa_def *ret = nir_build_alu(&b->nb, nir_alu_op_for_opencl_opcode(b, (enum OpenCLstd_Entrypoints)opcode),
                                    srcs[0], srcs[1], srcs[2], NULL);
   /* nir_op_bit_count always produces a 32-bit result; convert it to the
    * destination bit size.
    */
   if (opcode == OpenCLstd_Popcount)
      ret = nir_u2u(&b->nb, ret, glsl_get_bit_size(dest_type->type));
   return ret;
}

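/* Map from OpenCLstd opcode to the (unmangled) libclc function name, e.g.
 * remap_clc_opcode(OpenCLstd_Distance) returns "distance", which
 * mangle_and_find() then resolves against the clc shader.  The table is
 * indexed by opcode, so entries are sparse.
 */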
#define REMAP(op, str) [OpenCLstd_##op] = { str }
static const struct {
   const char *fn;
} remap_table[] = {
   REMAP(Distance, "distance"),
   REMAP(Fast_distance, "fast_distance"),
   REMAP(Fast_length, "fast_length"),
   REMAP(Fast_normalize, "fast_normalize"),
   REMAP(Half_rsqrt, "half_rsqrt"),
   REMAP(Half_sqrt, "half_sqrt"),
   REMAP(Length, "length"),
   REMAP(Normalize, "normalize"),
   REMAP(Degrees, "degrees"),
   REMAP(Radians, "radians"),
   REMAP(Rotate, "rotate"),
   REMAP(Smoothstep, "smoothstep"),
   REMAP(Step, "step"),

   REMAP(Pow, "pow"),
   REMAP(Pown, "pown"),
   REMAP(Powr, "powr"),
   REMAP(Rootn, "rootn"),
   REMAP(Modf, "modf"),

   REMAP(Acos, "acos"),
   REMAP(Acosh, "acosh"),
   REMAP(Acospi, "acospi"),
   REMAP(Asin, "asin"),
   REMAP(Asinh, "asinh"),
   REMAP(Asinpi, "asinpi"),
   REMAP(Atan, "atan"),
   REMAP(Atan2, "atan2"),
   REMAP(Atanh, "atanh"),
   REMAP(Atanpi, "atanpi"),
   REMAP(Atan2pi, "atan2pi"),
   REMAP(Cos, "cos"),
   REMAP(Cosh, "cosh"),
   REMAP(Cospi, "cospi"),
   REMAP(Sin, "sin"),
   REMAP(Sinh, "sinh"),
   REMAP(Sinpi, "sinpi"),
   REMAP(Tan, "tan"),
   REMAP(Tanh, "tanh"),
   REMAP(Tanpi, "tanpi"),
   REMAP(Sincos, "sincos"),
   REMAP(Fract, "fract"),
   REMAP(Frexp, "frexp"),
   REMAP(Fma, "fma"),
   REMAP(Fmod, "fmod"),

   REMAP(Half_cos, "cos"),
   REMAP(Half_exp, "exp"),
   REMAP(Half_exp2, "exp2"),
   REMAP(Half_exp10, "exp10"),
   REMAP(Half_log, "log"),
   REMAP(Half_log2, "log2"),
   REMAP(Half_log10, "log10"),
   REMAP(Half_powr, "powr"),
   REMAP(Half_sin, "sin"),
   REMAP(Half_tan, "tan"),

   REMAP(Remainder, "remainder"),
   REMAP(Remquo, "remquo"),
   REMAP(Hypot, "hypot"),
   REMAP(Exp, "exp"),
   REMAP(Exp2, "exp2"),
   REMAP(Exp10, "exp10"),
   REMAP(Expm1, "expm1"),
   REMAP(Ldexp, "ldexp"),

   REMAP(Ilogb, "ilogb"),
   REMAP(Log, "log"),
   REMAP(Log2, "log2"),
   REMAP(Log10, "log10"),
   REMAP(Log1p, "log1p"),
   REMAP(Logb, "logb"),

   REMAP(Cbrt, "cbrt"),
   REMAP(Erfc, "erfc"),
   REMAP(Erf, "erf"),

   REMAP(Lgamma, "lgamma"),
   REMAP(Lgamma_r, "lgamma_r"),
   REMAP(Tgamma, "tgamma"),

   REMAP(UMad_sat, "mad_sat"),
   REMAP(SMad_sat, "mad_sat"),

   REMAP(Shuffle, "shuffle"),
   REMAP(Shuffle2, "shuffle2"),
};
#undef REMAP

static const char *remap_clc_opcode(enum OpenCLstd_Entrypoints opcode)
{
   if (opcode >= ARRAY_SIZE(remap_table))
      return NULL;
   return remap_table[opcode].fn;
}

static struct vtn_type *
get_vtn_type_for_glsl_type(struct vtn_builder *b, const struct glsl_type *type)
{
   struct vtn_type *ret = rzalloc(b, struct vtn_type);
   assert(glsl_type_is_vector_or_scalar(type));
   ret->type = type;
   ret->length = glsl_get_vector_elements(type);
   ret->base_type = glsl_type_is_vector(type) ? vtn_base_type_vector : vtn_base_type_scalar;
   return ret;
}

static struct vtn_type *
get_pointer_type(struct vtn_builder *b, struct vtn_type *t, SpvStorageClass storage_class)
{
   struct vtn_type *ret = rzalloc(b, struct vtn_type);
   ret->type = nir_address_format_to_glsl_type(
      vtn_mode_to_address_format(
         b, vtn_storage_class_to_mode(b, storage_class, NULL, NULL)));
   ret->base_type = vtn_base_type_pointer;
   ret->storage_class = storage_class;
   ret->deref = t;
   return ret;
}

static struct vtn_type *
get_signed_type(struct vtn_builder *b, struct vtn_type *t)
{
   if (t->base_type == vtn_base_type_pointer) {
      return get_pointer_type(b, get_signed_type(b, t->deref), t->storage_class);
   }
   return get_vtn_type_for_glsl_type(
      b, glsl_vector_type(glsl_signed_base_type_of(glsl_get_base_type(t->type)),
                          glsl_get_vector_elements(t->type)));
}

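/* e.g. uint4 becomes int4; for pointers the pointee is converted, so a
 * global uint * mangles as "PU3AS1i" instead of "PU3AS1j".
 */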
static nir_ssa_def *
handle_clc_fn(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
              int num_srcs,
              nir_ssa_def **srcs,
              struct vtn_type **src_types,
              const struct vtn_type *dest_type)
{
   const char *name = remap_clc_opcode(opcode);
   if (!name)
       return NULL;

   /* Some functions expecting int (or pointer-to-int) params end up being
    * passed uint (or pointer-to-uint), which then mangles incorrectly.
    * See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers
    */
   int signed_param = -1;
   switch (opcode) {
   case OpenCLstd_Frexp:
   case OpenCLstd_Lgamma_r:
   case OpenCLstd_Pown:
   case OpenCLstd_Rootn:
   case OpenCLstd_Ldexp:
      signed_param = 1;
      break;
   case OpenCLstd_Remquo:
      signed_param = 2;
      break;
   case OpenCLstd_SMad_sat: {
      /* All parameters need to be converted to signed */
      src_types[0] = src_types[1] = src_types[2] = get_signed_type(b, src_types[0]);
      break;
   }
   default: break;
   }

   if (signed_param >= 0) {
      src_types[signed_param] = get_signed_type(b, src_types[signed_param]);
   }

   nir_deref_instr *ret_deref = NULL;

   if (!call_mangled_function(b, name, 0, num_srcs, src_types,
                              dest_type, srcs, &ret_deref))
      return NULL;

   return ret_deref ? nir_load_deref(&b->nb, ret_deref) : NULL;
}

static nir_ssa_def *
handle_special(struct vtn_builder *b, uint32_t opcode,
               unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
               const struct vtn_type *dest_type)
{
   nir_builder *nb = &b->nb;
   enum OpenCLstd_Entrypoints cl_opcode = (enum OpenCLstd_Entrypoints)opcode;

   switch (cl_opcode) {
   case OpenCLstd_SAbs_diff:
      /* These are easier to implement directly in NIR. */
      return nir_iabs_diff(nb, srcs[0], srcs[1]);
   case OpenCLstd_UAbs_diff:
      return nir_uabs_diff(nb, srcs[0], srcs[1]);
   case OpenCLstd_Bitselect:
      return nir_bitselect(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_SMad_hi:
      return nir_imad_hi(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_UMad_hi:
      return nir_umad_hi(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_SMul24:
      return nir_imul24(nb, srcs[0], srcs[1]);
   case OpenCLstd_UMul24:
      return nir_umul24(nb, srcs[0], srcs[1]);
   case OpenCLstd_SMad24:
      return nir_imad24(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_UMad24:
      return nir_umad24(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_FClamp:
      return nir_fclamp(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_SClamp:
      return nir_iclamp(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_UClamp:
      return nir_uclamp(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_Copysign:
      return nir_copysign(nb, srcs[0], srcs[1]);
   case OpenCLstd_Cross:
      if (dest_type->length == 4)
         return nir_cross4(nb, srcs[0], srcs[1]);
      return nir_cross3(nb, srcs[0], srcs[1]);
   case OpenCLstd_Fdim:
      return nir_fdim(nb, srcs[0], srcs[1]);
   case OpenCLstd_Fmod:
      if (nb->shader->options->lower_fmod)
         break;
      return nir_fmod(nb, srcs[0], srcs[1]);
   case OpenCLstd_Mad:
      return nir_fmad(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_Maxmag:
      return nir_maxmag(nb, srcs[0], srcs[1]);
   case OpenCLstd_Minmag:
      return nir_minmag(nb, srcs[0], srcs[1]);
   case OpenCLstd_Nan:
      return nir_nan(nb, srcs[0]);
   case OpenCLstd_Nextafter:
      return nir_nextafter(nb, srcs[0], srcs[1]);
   case OpenCLstd_Normalize:
      return nir_normalize(nb, srcs[0]);
   case OpenCLstd_Clz:
      return nir_clz_u(nb, srcs[0]);
   case OpenCLstd_Ctz:
      return nir_ctz_u(nb, srcs[0]);
   case OpenCLstd_Select:
      return nir_select(nb, srcs[0], srcs[1], srcs[2]);
   case OpenCLstd_S_Upsample:
   case OpenCLstd_U_Upsample:
      /* SPIR-V and CL have different definitions of upsample
       * (CL: ((wide)hi << bits(lo)) | lo), so just implement it in NIR.
       */
      return nir_upsample(nb, srcs[0], srcs[1]);
   case OpenCLstd_Native_exp:
      return nir_fexp(nb, srcs[0]);
   case OpenCLstd_Native_exp10:
      /* 10^x = 2^(x * log2(10)) */
      return nir_fexp2(nb, nir_fmul_imm(nb, srcs[0], log(10) / log(2)));
   case OpenCLstd_Native_log:
      return nir_flog(nb, srcs[0]);
   case OpenCLstd_Native_log10:
      /* log10(x) = log2(x) * log10(2) */
      return nir_fmul_imm(nb, nir_flog2(nb, srcs[0]), log(2) / log(10));
   case OpenCLstd_Native_tan:
      return nir_ftan(nb, srcs[0]);
   case OpenCLstd_Ldexp:
      if (nb->shader->options->lower_ldexp)
         break;
      return nir_ldexp(nb, srcs[0], srcs[1]);
   case OpenCLstd_Fma:
      /* FIXME: the software implementation only supports fp32 for now. */
      if (nb->shader->options->lower_ffma32 && srcs[0]->bit_size == 32)
         break;
      return nir_ffma(nb, srcs[0], srcs[1], srcs[2]);
   default:
      break;
   }

   /* Anything left over is handled by calling the matching libclc function. */
   nir_ssa_def *ret = handle_clc_fn(b, opcode, num_srcs, srcs, src_types, dest_type);
   if (!ret)
      vtn_fail("No NIR equivalent");

   return ret;
}

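/* Core SPIR-V group ops used by OpenCL kernels are also lowered to libclc
 * calls.  Note the const_mask of (1 << 1) below: it marks the source
 * pointer as const so the mangled name matches libclc's
 * "const __global gentype *" parameter of async_work_group_strided_copy.
 */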
static nir_ssa_def *
handle_core(struct vtn_builder *b, uint32_t opcode,
            unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
            const struct vtn_type *dest_type)
{
   nir_deref_instr *ret_deref = NULL;

   switch ((SpvOp)opcode) {
   case SpvOpGroupAsyncCopy: {
      /* Libclc doesn't include 3-component overloads of the async copy
       * functions.  However, the CLC spec says:
       *    async_work_group_copy and async_work_group_strided_copy for
       *    3-component vector types behave as async_work_group_copy and
       *    async_work_group_strided_copy respectively for 4-component
       *    vector types.
       */
      for (unsigned i = 0; i < num_srcs; ++i) {
         if (src_types[i]->base_type == vtn_base_type_pointer &&
             src_types[i]->deref->base_type == vtn_base_type_vector &&
             src_types[i]->deref->length == 3) {
            src_types[i] =
               get_pointer_type(b,
                                get_vtn_type_for_glsl_type(b, glsl_replace_vector_type(src_types[i]->deref->type, 4)),
                                src_types[i]->storage_class);
         }
      }
      if (!call_mangled_function(b, "async_work_group_strided_copy", (1 << 1), num_srcs, src_types, dest_type, srcs, &ret_deref))
         return NULL;
      break;
   }
   case SpvOpGroupWaitEvents: {
      src_types[0] = get_vtn_type_for_glsl_type(b, glsl_int_type());
      if (!call_mangled_function(b, "wait_group_events", 0, num_srcs, src_types, dest_type, srcs, &ret_deref))
         return NULL;
      break;
   }
   default:
      return NULL;
   }

   return ret_deref ? nir_load_deref(&b->nb, ret_deref) : NULL;
}


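/* Shared implementation of the vload/vstore families.  The offset operand
 * is in units of whole vectors, e.g. vload4(offset, p) reads
 * p[4 * offset] through p[4 * offset + 3].  For the aligned
 * "vloada"/"vstorea" half variants, a 3-component vector is offset and
 * aligned as if it had 4 elements.
 */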
static void
_handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
                     const uint32_t *w, unsigned count, bool load,
                     bool vec_aligned, nir_rounding_mode rounding)
{
   struct vtn_type *type;
   if (load)
      type = vtn_get_type(b, w[1]);
   else
      type = vtn_get_value_type(b, w[5]);
   unsigned a = load ? 0 : 1;

   enum glsl_base_type base_type = glsl_get_base_type(type->type);
   unsigned components = glsl_get_vector_elements(type->type);

   nir_ssa_def *offset = vtn_get_nir_ssa(b, w[5 + a]);
   struct vtn_value *p = vtn_value(b, w[6 + a], vtn_value_type_pointer);

   enum glsl_base_type ptr_base_type =
      glsl_get_base_type(p->pointer->type->type);
   if (base_type != ptr_base_type) {
      vtn_fail_if(ptr_base_type != GLSL_TYPE_FLOAT16 ||
                  (base_type != GLSL_TYPE_FLOAT &&
                   base_type != GLSL_TYPE_DOUBLE),
                  "vload/vstore cannot do type conversion. "
                  "vload/vstore_half can only convert from half to other "
                  "floating-point types.");
   }

   struct vtn_ssa_value *comps[NIR_MAX_VEC_COMPONENTS];
   nir_ssa_def *ncomps[NIR_MAX_VEC_COMPONENTS];

   nir_ssa_def *moffset = nir_imul_imm(&b->nb, offset,
                                       (vec_aligned && components == 3) ? 4 : components);
   nir_deref_instr *deref = vtn_pointer_to_deref(b, p->pointer);

   unsigned alignment = vec_aligned ? glsl_get_cl_alignment(type->type) :
                                      glsl_get_bit_size(type->type) / 8;
   deref = nir_alignment_deref_cast(&b->nb, deref, alignment, 0);

   for (int i = 0; i < components; i++) {
      nir_ssa_def *coffset = nir_iadd_imm(&b->nb, moffset, i);
      nir_deref_instr *arr_deref = nir_build_deref_ptr_as_array(&b->nb, deref, coffset);

      if (load) {
         comps[i] = vtn_local_load(b, arr_deref, p->type->access);
         ncomps[i] = comps[i]->def;
         if (base_type != ptr_base_type) {
            assert(ptr_base_type == GLSL_TYPE_FLOAT16 &&
                   (base_type == GLSL_TYPE_FLOAT ||
                    base_type == GLSL_TYPE_DOUBLE));
            ncomps[i] = nir_f2fN(&b->nb, ncomps[i],
                                 glsl_base_type_get_bit_size(base_type));
         }
      } else {
         struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, glsl_scalar_type(base_type));
         struct vtn_ssa_value *val = vtn_ssa_value(b, w[5]);
         ssa->def = nir_channel(&b->nb, val->def, i);
         if (base_type != ptr_base_type) {
            assert(ptr_base_type == GLSL_TYPE_FLOAT16 &&
                   (base_type == GLSL_TYPE_FLOAT ||
                    base_type == GLSL_TYPE_DOUBLE));
            if (rounding == nir_rounding_mode_undef) {
               ssa->def = nir_f2f16(&b->nb, ssa->def);
            } else {
               ssa->def = nir_convert_alu_types(&b->nb, ssa->def,
                                                nir_type_float,
                                                nir_type_float16,
                                                rounding, false);
            }
         }
         vtn_local_store(b, ssa, arr_deref, p->type->access);
      }
   }
   if (load) {
      vtn_push_nir_ssa(b, w[2], nir_vec(&b->nb, ncomps, components));
   }
}

static void
vtn_handle_opencl_vload(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
                        const uint32_t *w, unsigned count)
{
   _handle_v_load_store(b, opcode, w, count, true,
                        opcode == OpenCLstd_Vloada_halfn,
                        nir_rounding_mode_undef);
}

static void
vtn_handle_opencl_vstore(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
                         const uint32_t *w, unsigned count)
{
   _handle_v_load_store(b, opcode, w, count, false,
                        opcode == OpenCLstd_Vstorea_halfn,
                        nir_rounding_mode_undef);
}

static void
vtn_handle_opencl_vstore_half_r(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
                                const uint32_t *w, unsigned count)
{
   _handle_v_load_store(b, opcode, w, count, false,
                        opcode == OpenCLstd_Vstorea_halfn_r,
                        vtn_rounding_mode_to_nir(b, w[8]));
}

static nir_ssa_def *
handle_printf(struct vtn_builder *b, uint32_t opcode,
              unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
              const struct vtn_type *dest_type)
{
   /* printf is not implemented; return -1, the value CL printf produces on
    * failure.
    */
   return nir_imm_int(&b->nb, -1);
}

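/* OpenCL round() rounds halfway cases away from zero, e.g.
 * round(0.5f) == 1.0f and round(-0.5f) == -1.0f, which is why this can't
 * just use Rint/nir_op_fround_even above.
 */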
static nir_ssa_def *
handle_round(struct vtn_builder *b, uint32_t opcode,
             unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
             const struct vtn_type *dest_type)
{
   nir_ssa_def *src = srcs[0];
   nir_builder *nb = &b->nb;
   nir_ssa_def *half = nir_imm_floatN_t(nb, 0.5, src->bit_size);
   nir_ssa_def *truncated = nir_ftrunc(nb, src);
   nir_ssa_def *remainder = nir_fsub(nb, src, truncated);

   return nir_bcsel(nb, nir_fge(nb, nir_fabs(nb, remainder), half),
                    nir_fadd(nb, truncated, nir_fsign(nb, src)), truncated);
}

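/* CL shuffle(), e.g. shuffle((uint4)(1, 2, 3, 4), (uint4)(3, 3, 0, 0)) ==
 * (uint4)(4, 4, 1, 1).  Only the low bits of each mask component are
 * significant, which the iand below implements (the vector widths involved
 * are powers of two).
 */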
static nir_ssa_def *
handle_shuffle(struct vtn_builder *b, uint32_t opcode,
               unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
               const struct vtn_type *dest_type)
{
   struct nir_ssa_def *input = srcs[0];
   struct nir_ssa_def *mask = srcs[1];

   unsigned out_elems = dest_type->length;
   nir_ssa_def *outres[NIR_MAX_VEC_COMPONENTS];
   unsigned in_elems = input->num_components;
   if (mask->bit_size != 32)
      mask = nir_u2u32(&b->nb, mask);
   mask = nir_iand(&b->nb, mask, nir_imm_intN_t(&b->nb, in_elems - 1, mask->bit_size));
   for (unsigned i = 0; i < out_elems; i++)
      outres[i] = nir_vector_extract(&b->nb, input, nir_channel(&b->nb, mask, i));

   return nir_vec(&b->nb, outres, out_elems);
}

static nir_ssa_def *
handle_shuffle2(struct vtn_builder *b, uint32_t opcode,
                unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
                const struct vtn_type *dest_type)
{
   struct nir_ssa_def *input0 = srcs[0];
   struct nir_ssa_def *input1 = srcs[1];
   struct nir_ssa_def *mask = srcs[2];

   unsigned out_elems = dest_type->length;
   nir_ssa_def *outres[NIR_MAX_VEC_COMPONENTS];
   unsigned in_elems = input0->num_components;
   unsigned total_mask = 2 * in_elems - 1;
   unsigned half_mask = in_elems - 1;
   if (mask->bit_size != 32)
      mask = nir_u2u32(&b->nb, mask);
   mask = nir_iand(&b->nb, mask, nir_imm_intN_t(&b->nb, total_mask, mask->bit_size));
   for (unsigned i = 0; i < out_elems; i++) {
      nir_ssa_def *this_mask = nir_channel(&b->nb, mask, i);
      nir_ssa_def *vmask = nir_iand(&b->nb, this_mask, nir_imm_intN_t(&b->nb, half_mask, mask->bit_size));
      nir_ssa_def *val0 = nir_vector_extract(&b->nb, input0, vmask);
      nir_ssa_def *val1 = nir_vector_extract(&b->nb, input1, vmask);
      nir_ssa_def *sel = nir_ilt(&b->nb, this_mask, nir_imm_intN_t(&b->nb, in_elems, mask->bit_size));
      outres[i] = nir_bcsel(&b->nb, sel, val0, val1);
   }
   return nir_vec(&b->nb, outres, out_elems);
}

bool
vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode,
                              const uint32_t *w, unsigned count)
{
   enum OpenCLstd_Entrypoints cl_opcode = (enum OpenCLstd_Entrypoints) ext_opcode;

   switch (cl_opcode) {
   case OpenCLstd_Fabs:
   case OpenCLstd_SAbs:
   case OpenCLstd_UAbs:
   case OpenCLstd_SAdd_sat:
   case OpenCLstd_UAdd_sat:
   case OpenCLstd_Ceil:
   case OpenCLstd_Floor:
   case OpenCLstd_Fmax:
   case OpenCLstd_SHadd:
   case OpenCLstd_UHadd:
   case OpenCLstd_SMax:
   case OpenCLstd_UMax:
   case OpenCLstd_Fmin:
   case OpenCLstd_SMin:
   case OpenCLstd_UMin:
   case OpenCLstd_Mix:
   case OpenCLstd_Native_cos:
   case OpenCLstd_Native_divide:
   case OpenCLstd_Native_exp2:
   case OpenCLstd_Native_log2:
   case OpenCLstd_Native_powr:
   case OpenCLstd_Native_recip:
   case OpenCLstd_Native_rsqrt:
   case OpenCLstd_Native_sin:
   case OpenCLstd_Native_sqrt:
   case OpenCLstd_SMul_hi:
   case OpenCLstd_UMul_hi:
   case OpenCLstd_Popcount:
   case OpenCLstd_SRhadd:
   case OpenCLstd_URhadd:
   case OpenCLstd_Rsqrt:
   case OpenCLstd_Sign:
   case OpenCLstd_Sqrt:
   case OpenCLstd_SSub_sat:
   case OpenCLstd_USub_sat:
   case OpenCLstd_Trunc:
   case OpenCLstd_Rint:
   case OpenCLstd_Half_divide:
   case OpenCLstd_Half_recip:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_alu);
      return true;
   case OpenCLstd_SAbs_diff:
   case OpenCLstd_UAbs_diff:
   case OpenCLstd_SMad_hi:
   case OpenCLstd_UMad_hi:
   case OpenCLstd_SMad24:
   case OpenCLstd_UMad24:
   case OpenCLstd_SMul24:
   case OpenCLstd_UMul24:
   case OpenCLstd_Bitselect:
   case OpenCLstd_FClamp:
   case OpenCLstd_SClamp:
   case OpenCLstd_UClamp:
   case OpenCLstd_Copysign:
   case OpenCLstd_Cross:
   case OpenCLstd_Degrees:
   case OpenCLstd_Fdim:
   case OpenCLstd_Fma:
   case OpenCLstd_Distance:
   case OpenCLstd_Fast_distance:
   case OpenCLstd_Fast_length:
   case OpenCLstd_Fast_normalize:
   case OpenCLstd_Half_rsqrt:
   case OpenCLstd_Half_sqrt:
   case OpenCLstd_Length:
   case OpenCLstd_Mad:
   case OpenCLstd_Maxmag:
   case OpenCLstd_Minmag:
   case OpenCLstd_Nan:
   case OpenCLstd_Nextafter:
   case OpenCLstd_Normalize:
   case OpenCLstd_Radians:
   case OpenCLstd_Rotate:
   case OpenCLstd_Select:
   case OpenCLstd_Step:
   case OpenCLstd_Smoothstep:
   case OpenCLstd_S_Upsample:
   case OpenCLstd_U_Upsample:
   case OpenCLstd_Clz:
   case OpenCLstd_Ctz:
   case OpenCLstd_Native_exp:
   case OpenCLstd_Native_exp10:
   case OpenCLstd_Native_log:
   case OpenCLstd_Native_log10:
   case OpenCLstd_Acos:
   case OpenCLstd_Acosh:
   case OpenCLstd_Acospi:
   case OpenCLstd_Asin:
   case OpenCLstd_Asinh:
   case OpenCLstd_Asinpi:
   case OpenCLstd_Atan:
   case OpenCLstd_Atan2:
   case OpenCLstd_Atanh:
   case OpenCLstd_Atanpi:
   case OpenCLstd_Atan2pi:
   case OpenCLstd_Fract:
   case OpenCLstd_Frexp:
   case OpenCLstd_Exp:
   case OpenCLstd_Exp2:
   case OpenCLstd_Expm1:
   case OpenCLstd_Exp10:
   case OpenCLstd_Fmod:
   case OpenCLstd_Ilogb:
   case OpenCLstd_Log:
   case OpenCLstd_Log2:
   case OpenCLstd_Log10:
   case OpenCLstd_Log1p:
   case OpenCLstd_Logb:
   case OpenCLstd_Ldexp:
   case OpenCLstd_Cos:
   case OpenCLstd_Cosh:
   case OpenCLstd_Cospi:
   case OpenCLstd_Sin:
   case OpenCLstd_Sinh:
   case OpenCLstd_Sinpi:
   case OpenCLstd_Tan:
   case OpenCLstd_Tanh:
   case OpenCLstd_Tanpi:
   case OpenCLstd_Cbrt:
   case OpenCLstd_Erfc:
   case OpenCLstd_Erf:
   case OpenCLstd_Lgamma:
   case OpenCLstd_Lgamma_r:
   case OpenCLstd_Tgamma:
   case OpenCLstd_Pow:
   case OpenCLstd_Powr:
   case OpenCLstd_Pown:
   case OpenCLstd_Rootn:
   case OpenCLstd_Remainder:
   case OpenCLstd_Remquo:
   case OpenCLstd_Hypot:
   case OpenCLstd_Sincos:
   case OpenCLstd_Modf:
   case OpenCLstd_UMad_sat:
   case OpenCLstd_SMad_sat:
   case OpenCLstd_Native_tan:
   case OpenCLstd_Half_cos:
   case OpenCLstd_Half_exp:
   case OpenCLstd_Half_exp2:
   case OpenCLstd_Half_exp10:
   case OpenCLstd_Half_log:
   case OpenCLstd_Half_log2:
   case OpenCLstd_Half_log10:
   case OpenCLstd_Half_powr:
   case OpenCLstd_Half_sin:
   case OpenCLstd_Half_tan:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_special);
      return true;
   case OpenCLstd_Vloadn:
   case OpenCLstd_Vload_half:
   case OpenCLstd_Vload_halfn:
   case OpenCLstd_Vloada_halfn:
      vtn_handle_opencl_vload(b, cl_opcode, w, count);
      return true;
   case OpenCLstd_Vstoren:
   case OpenCLstd_Vstore_half:
   case OpenCLstd_Vstore_halfn:
   case OpenCLstd_Vstorea_halfn:
      vtn_handle_opencl_vstore(b, cl_opcode, w, count);
      return true;
   case OpenCLstd_Vstore_half_r:
   case OpenCLstd_Vstore_halfn_r:
   case OpenCLstd_Vstorea_halfn_r:
      vtn_handle_opencl_vstore_half_r(b, cl_opcode, w, count);
      return true;
   case OpenCLstd_Shuffle:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_shuffle);
      return true;
   case OpenCLstd_Shuffle2:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_shuffle2);
      return true;
   case OpenCLstd_Round:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_round);
      return true;
   case OpenCLstd_Printf:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_printf);
      return true;
   case OpenCLstd_Prefetch:
      /* TODO: maybe add a nir instruction for this? */
      return true;
   default:
      vtn_fail("unhandled opencl opc: %u\n", ext_opcode);
      return false;
   }
}

bool
vtn_handle_opencl_core_instruction(struct vtn_builder *b, SpvOp opcode,
                                   const uint32_t *w, unsigned count)
{
   switch (opcode) {
   case SpvOpGroupAsyncCopy:
      handle_instr(b, opcode, w + 4, count - 4, w + 1, handle_core);
      return true;
   case SpvOpGroupWaitEvents:
      handle_instr(b, opcode, w + 2, count - 2, NULL, handle_core);
      return true;
   default:
      return false;
   }
   return true;
}
