• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1 /*
2  * Copyright © 2018 Red Hat
3  *
4  * Permission is hereby granted, free of charge, to any person obtaining a
5  * copy of this software and associated documentation files (the "Software"),
6  * to deal in the Software without restriction, including without limitation
7  * the rights to use, copy, modify, merge, publish, distribute, sublicense,
8  * and/or sell copies of the Software, and to permit persons to whom the
9  * Software is furnished to do so, subject to the following conditions:
10  *
11  * The above copyright notice and this permission notice (including the next
12  * paragraph) shall be included in all copies or substantial portions of the
13  * Software.
14  *
15  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL
18  * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
20  * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
21  * IN THE SOFTWARE.
22  *
23  * Authors:
24  *    Rob Clark (robdclark@gmail.com)
25  */
26 
27 #include "math.h"
28 #include "nir/nir_builtin_builder.h"
29 
30 #include "vtn_private.h"
31 #include "OpenCL.std.h"
32 
33 typedef nir_ssa_def *(*nir_handler)(struct vtn_builder *b,
34                                     uint32_t opcode,
35                                     unsigned num_srcs, nir_ssa_def **srcs,
36                                     struct vtn_type **src_types,
37                                     const struct vtn_type *dest_type);
38 
/* Map a SPIR-V storage class to the numeric address space used in the
 * Itanium-style mangled names libclc/LLVM produce, or -1 if the storage
 * class has no corresponding address space.
 */
static int to_llvm_address_space(SpvStorageClass mode)
{
   if (mode == SpvStorageClassPrivate || mode == SpvStorageClassFunction)
      return 0;
   if (mode == SpvStorageClassCrossWorkgroup)
      return 1;
   if (mode == SpvStorageClassUniform || mode == SpvStorageClassUniformConstant)
      return 2;
   if (mode == SpvStorageClassWorkgroup)
      return 3;
   return -1;
}
51 
52 
53 static void
vtn_opencl_mangle(const char * in_name,uint32_t const_mask,int ntypes,struct vtn_type ** src_types,char ** outstring)54 vtn_opencl_mangle(const char *in_name,
55                   uint32_t const_mask,
56                   int ntypes, struct vtn_type **src_types,
57                   char **outstring)
58 {
59    char local_name[256] = "";
60    char *args_str = local_name + sprintf(local_name, "_Z%zu%s", strlen(in_name), in_name);
61 
62    for (unsigned i = 0; i < ntypes; ++i) {
63       const struct glsl_type *type = src_types[i]->type;
64       enum vtn_base_type base_type = src_types[i]->base_type;
65       if (src_types[i]->base_type == vtn_base_type_pointer) {
66          *(args_str++) = 'P';
67          int address_space = to_llvm_address_space(src_types[i]->storage_class);
68          if (address_space > 0)
69             args_str += sprintf(args_str, "U3AS%d", address_space);
70 
71          type = src_types[i]->deref->type;
72          base_type = src_types[i]->deref->base_type;
73       }
74 
75       if (const_mask & (1 << i))
76          *(args_str++) = 'K';
77 
78       unsigned num_elements = glsl_get_components(type);
79       if (num_elements > 1) {
80          /* Vectors are not treated as built-ins for mangling, so check for substitution.
81           * In theory, we'd need to know which substitution value this is. In practice,
82           * the functions we need from libclc only support 1
83           */
84          bool substitution = false;
85          for (unsigned j = 0; j < i; ++j) {
86             const struct glsl_type *other_type = src_types[j]->base_type == vtn_base_type_pointer ?
87                src_types[j]->deref->type : src_types[j]->type;
88             if (type == other_type) {
89                substitution = true;
90                break;
91             }
92          }
93 
94          if (substitution) {
95             args_str += sprintf(args_str, "S_");
96             continue;
97          } else
98             args_str += sprintf(args_str, "Dv%d_", num_elements);
99       }
100 
101       const char *suffix = NULL;
102       switch (base_type) {
103       case vtn_base_type_sampler: suffix = "11ocl_sampler"; break;
104       case vtn_base_type_event: suffix = "9ocl_event"; break;
105       default: {
106          const char *primitives[] = {
107             [GLSL_TYPE_UINT] = "j",
108             [GLSL_TYPE_INT] = "i",
109             [GLSL_TYPE_FLOAT] = "f",
110             [GLSL_TYPE_FLOAT16] = "Dh",
111             [GLSL_TYPE_DOUBLE] = "d",
112             [GLSL_TYPE_UINT8] = "h",
113             [GLSL_TYPE_INT8] = "c",
114             [GLSL_TYPE_UINT16] = "t",
115             [GLSL_TYPE_INT16] = "s",
116             [GLSL_TYPE_UINT64] = "m",
117             [GLSL_TYPE_INT64] = "l",
118             [GLSL_TYPE_BOOL] = "b",
119             [GLSL_TYPE_ERROR] = NULL,
120          };
121          enum glsl_base_type glsl_base_type = glsl_get_base_type(type);
122          assert(glsl_base_type < ARRAY_SIZE(primitives) && primitives[glsl_base_type]);
123          suffix = primitives[glsl_base_type];
124          break;
125       }
126       }
127       args_str += sprintf(args_str, "%s", suffix);
128    }
129 
130    *outstring = strdup(local_name);
131 }
132 
mangle_and_find(struct vtn_builder * b,const char * name,uint32_t const_mask,uint32_t num_srcs,struct vtn_type ** src_types)133 static nir_function *mangle_and_find(struct vtn_builder *b,
134                                      const char *name,
135                                      uint32_t const_mask,
136                                      uint32_t num_srcs,
137                                      struct vtn_type **src_types)
138 {
139    char *mname;
140    nir_function *found = NULL;
141 
142    vtn_opencl_mangle(name, const_mask, num_srcs, src_types, &mname);
143    /* try and find in current shader first. */
144    nir_foreach_function(funcs, b->shader) {
145       if (!strcmp(funcs->name, mname)) {
146          found = funcs;
147          break;
148       }
149    }
150    /* if not found here find in clc shader and create a decl mirroring it */
151    if (!found && b->options->clc_shader && b->options->clc_shader != b->shader) {
152       nir_foreach_function(funcs, b->options->clc_shader) {
153          if (!strcmp(funcs->name, mname)) {
154             found = funcs;
155             break;
156          }
157       }
158       if (found) {
159          nir_function *decl = nir_function_create(b->shader, mname);
160          decl->num_params = found->num_params;
161          decl->params = ralloc_array(b->shader, nir_parameter, decl->num_params);
162          for (unsigned i = 0; i < decl->num_params; i++) {
163             decl->params[i] = found->params[i];
164          }
165          found = decl;
166       }
167    }
168    if (!found)
169       vtn_fail("Can't find clc function %s\n", mname);
170    free(mname);
171    return found;
172 }
173 
/* Emit a call to the named libclc function.  Returns false if the function
 * cannot be resolved.  When dest_type is non-NULL, a local temporary is
 * created for the return value, passed as the first call parameter, and a
 * deref to it is stored in *ret_deref_ptr (NULL otherwise).
 */
static bool call_mangled_function(struct vtn_builder *b,
                                  const char *name,
                                  uint32_t const_mask,
                                  uint32_t num_srcs,
                                  struct vtn_type **src_types,
                                  const struct vtn_type *dest_type,
                                  nir_ssa_def **srcs,
                                  nir_deref_instr **ret_deref_ptr)
{
   nir_function *func = mangle_and_find(b, name, const_mask, num_srcs, src_types);
   if (func == NULL)
      return false;

   nir_call_instr *call = nir_call_instr_create(b->shader, func);

   nir_deref_instr *ret_deref = NULL;
   uint32_t param_idx = 0;

   /* Return values travel through an out-parameter in slot 0. */
   if (dest_type != NULL) {
      nir_variable *ret_tmp =
         nir_local_variable_create(b->nb.impl,
                                   glsl_get_bare_type(dest_type->type),
                                   "return_tmp");
      ret_deref = nir_build_deref_var(&b->nb, ret_tmp);
      call->params[param_idx++] = nir_src_for_ssa(&ret_deref->dest.ssa);
   }

   for (uint32_t i = 0; i < num_srcs; i++)
      call->params[param_idx++] = nir_src_for_ssa(srcs[i]);

   nir_builder_instr_insert(&b->nb, &call->instr);

   *ret_deref_ptr = ret_deref;
   return true;
}
206 
207 static void
handle_instr(struct vtn_builder * b,uint32_t opcode,const uint32_t * w_src,unsigned num_srcs,const uint32_t * w_dest,nir_handler handler)208 handle_instr(struct vtn_builder *b, uint32_t opcode,
209              const uint32_t *w_src, unsigned num_srcs, const uint32_t *w_dest, nir_handler handler)
210 {
211    struct vtn_type *dest_type = w_dest ? vtn_get_type(b, w_dest[0]) : NULL;
212 
213    nir_ssa_def *srcs[5] = { NULL };
214    struct vtn_type *src_types[5] = { NULL };
215    vtn_assert(num_srcs <= ARRAY_SIZE(srcs));
216    for (unsigned i = 0; i < num_srcs; i++) {
217       struct vtn_value *val = vtn_untyped_value(b, w_src[i]);
218       struct vtn_ssa_value *ssa = vtn_ssa_value(b, w_src[i]);
219       srcs[i] = ssa->def;
220       src_types[i] = val->type;
221    }
222 
223    nir_ssa_def *result = handler(b, opcode, num_srcs, srcs, src_types, dest_type);
224    if (result) {
225       vtn_push_nir_ssa(b, w_dest[1], result);
226    } else {
227       vtn_assert(dest_type == NULL);
228    }
229 }
230 
231 static nir_op
nir_alu_op_for_opencl_opcode(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode)232 nir_alu_op_for_opencl_opcode(struct vtn_builder *b,
233                              enum OpenCLstd_Entrypoints opcode)
234 {
235    switch (opcode) {
236    case OpenCLstd_Fabs: return nir_op_fabs;
237    case OpenCLstd_SAbs: return nir_op_iabs;
238    case OpenCLstd_SAdd_sat: return nir_op_iadd_sat;
239    case OpenCLstd_UAdd_sat: return nir_op_uadd_sat;
240    case OpenCLstd_Ceil: return nir_op_fceil;
241    case OpenCLstd_Floor: return nir_op_ffloor;
242    case OpenCLstd_SHadd: return nir_op_ihadd;
243    case OpenCLstd_UHadd: return nir_op_uhadd;
244    case OpenCLstd_Fmax: return nir_op_fmax;
245    case OpenCLstd_SMax: return nir_op_imax;
246    case OpenCLstd_UMax: return nir_op_umax;
247    case OpenCLstd_Fmin: return nir_op_fmin;
248    case OpenCLstd_SMin: return nir_op_imin;
249    case OpenCLstd_UMin: return nir_op_umin;
250    case OpenCLstd_Mix: return nir_op_flrp;
251    case OpenCLstd_Native_cos: return nir_op_fcos;
252    case OpenCLstd_Native_divide: return nir_op_fdiv;
253    case OpenCLstd_Native_exp2: return nir_op_fexp2;
254    case OpenCLstd_Native_log2: return nir_op_flog2;
255    case OpenCLstd_Native_powr: return nir_op_fpow;
256    case OpenCLstd_Native_recip: return nir_op_frcp;
257    case OpenCLstd_Native_rsqrt: return nir_op_frsq;
258    case OpenCLstd_Native_sin: return nir_op_fsin;
259    case OpenCLstd_Native_sqrt: return nir_op_fsqrt;
260    case OpenCLstd_SMul_hi: return nir_op_imul_high;
261    case OpenCLstd_UMul_hi: return nir_op_umul_high;
262    case OpenCLstd_Popcount: return nir_op_bit_count;
263    case OpenCLstd_SRhadd: return nir_op_irhadd;
264    case OpenCLstd_URhadd: return nir_op_urhadd;
265    case OpenCLstd_Rsqrt: return nir_op_frsq;
266    case OpenCLstd_Sign: return nir_op_fsign;
267    case OpenCLstd_Sqrt: return nir_op_fsqrt;
268    case OpenCLstd_SSub_sat: return nir_op_isub_sat;
269    case OpenCLstd_USub_sat: return nir_op_usub_sat;
270    case OpenCLstd_Trunc: return nir_op_ftrunc;
271    case OpenCLstd_Rint: return nir_op_fround_even;
272    case OpenCLstd_Half_divide: return nir_op_fdiv;
273    case OpenCLstd_Half_recip: return nir_op_frcp;
274    /* uhm... */
275    case OpenCLstd_UAbs: return nir_op_mov;
276    default:
277       vtn_fail("No NIR equivalent");
278    }
279 }
280 
281 static nir_ssa_def *
handle_alu(struct vtn_builder * b,uint32_t opcode,unsigned num_srcs,nir_ssa_def ** srcs,struct vtn_type ** src_types,const struct vtn_type * dest_type)282 handle_alu(struct vtn_builder *b, uint32_t opcode,
283            unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
284            const struct vtn_type *dest_type)
285 {
286    nir_ssa_def *ret = nir_build_alu(&b->nb, nir_alu_op_for_opencl_opcode(b, (enum OpenCLstd_Entrypoints)opcode),
287                                     srcs[0], srcs[1], srcs[2], NULL);
288    if (opcode == OpenCLstd_Popcount)
289       ret = nir_u2u(&b->nb, ret, glsl_get_bit_size(dest_type->type));
290    return ret;
291 }
292 
293 #define REMAP(op, str) [OpenCLstd_##op] = { str }
294 static const struct {
295    const char *fn;
296 } remap_table[] = {
297    REMAP(Distance, "distance"),
298    REMAP(Fast_distance, "fast_distance"),
299    REMAP(Fast_length, "fast_length"),
300    REMAP(Fast_normalize, "fast_normalize"),
301    REMAP(Half_rsqrt, "half_rsqrt"),
302    REMAP(Half_sqrt, "half_sqrt"),
303    REMAP(Length, "length"),
304    REMAP(Normalize, "normalize"),
305    REMAP(Degrees, "degrees"),
306    REMAP(Radians, "radians"),
307    REMAP(Rotate, "rotate"),
308    REMAP(Smoothstep, "smoothstep"),
309    REMAP(Step, "step"),
310 
311    REMAP(Pow, "pow"),
312    REMAP(Pown, "pown"),
313    REMAP(Powr, "powr"),
314    REMAP(Rootn, "rootn"),
315    REMAP(Modf, "modf"),
316 
317    REMAP(Acos, "acos"),
318    REMAP(Acosh, "acosh"),
319    REMAP(Acospi, "acospi"),
320    REMAP(Asin, "asin"),
321    REMAP(Asinh, "asinh"),
322    REMAP(Asinpi, "asinpi"),
323    REMAP(Atan, "atan"),
324    REMAP(Atan2, "atan2"),
325    REMAP(Atanh, "atanh"),
326    REMAP(Atanpi, "atanpi"),
327    REMAP(Atan2pi, "atan2pi"),
328    REMAP(Cos, "cos"),
329    REMAP(Cosh, "cosh"),
330    REMAP(Cospi, "cospi"),
331    REMAP(Sin, "sin"),
332    REMAP(Sinh, "sinh"),
333    REMAP(Sinpi, "sinpi"),
334    REMAP(Tan, "tan"),
335    REMAP(Tanh, "tanh"),
336    REMAP(Tanpi, "tanpi"),
337    REMAP(Sincos, "sincos"),
338    REMAP(Fract, "fract"),
339    REMAP(Frexp, "frexp"),
340    REMAP(Fma, "fma"),
341    REMAP(Fmod, "fmod"),
342 
343    REMAP(Half_cos, "cos"),
344    REMAP(Half_exp, "exp"),
345    REMAP(Half_exp2, "exp2"),
346    REMAP(Half_exp10, "exp10"),
347    REMAP(Half_log, "log"),
348    REMAP(Half_log2, "log2"),
349    REMAP(Half_log10, "log10"),
350    REMAP(Half_powr, "powr"),
351    REMAP(Half_sin, "sin"),
352    REMAP(Half_tan, "tan"),
353 
354    REMAP(Remainder, "remainder"),
355    REMAP(Remquo, "remquo"),
356    REMAP(Hypot, "hypot"),
357    REMAP(Exp, "exp"),
358    REMAP(Exp2, "exp2"),
359    REMAP(Exp10, "exp10"),
360    REMAP(Expm1, "expm1"),
361    REMAP(Ldexp, "ldexp"),
362 
363    REMAP(Ilogb, "ilogb"),
364    REMAP(Log, "log"),
365    REMAP(Log2, "log2"),
366    REMAP(Log10, "log10"),
367    REMAP(Log1p, "log1p"),
368    REMAP(Logb, "logb"),
369 
370    REMAP(Cbrt, "cbrt"),
371    REMAP(Erfc, "erfc"),
372    REMAP(Erf, "erf"),
373 
374    REMAP(Lgamma, "lgamma"),
375    REMAP(Lgamma_r, "lgamma_r"),
376    REMAP(Tgamma, "tgamma"),
377 
378    REMAP(UMad_sat, "mad_sat"),
379    REMAP(SMad_sat, "mad_sat"),
380 
381    REMAP(Shuffle, "shuffle"),
382    REMAP(Shuffle2, "shuffle2"),
383 };
384 #undef REMAP
385 
remap_clc_opcode(enum OpenCLstd_Entrypoints opcode)386 static const char *remap_clc_opcode(enum OpenCLstd_Entrypoints opcode)
387 {
388    if (opcode >= (sizeof(remap_table) / sizeof(const char *)))
389       return NULL;
390    return remap_table[opcode].fn;
391 }
392 
393 static struct vtn_type *
get_vtn_type_for_glsl_type(struct vtn_builder * b,const struct glsl_type * type)394 get_vtn_type_for_glsl_type(struct vtn_builder *b, const struct glsl_type *type)
395 {
396    struct vtn_type *ret = rzalloc(b, struct vtn_type);
397    assert(glsl_type_is_vector_or_scalar(type));
398    ret->type = type;
399    ret->length = glsl_get_vector_elements(type);
400    ret->base_type = glsl_type_is_vector(type) ? vtn_base_type_vector : vtn_base_type_scalar;
401    return ret;
402 }
403 
404 static struct vtn_type *
get_pointer_type(struct vtn_builder * b,struct vtn_type * t,SpvStorageClass storage_class)405 get_pointer_type(struct vtn_builder *b, struct vtn_type *t, SpvStorageClass storage_class)
406 {
407    struct vtn_type *ret = rzalloc(b, struct vtn_type);
408    ret->type = nir_address_format_to_glsl_type(
409             vtn_mode_to_address_format(
410                b, vtn_storage_class_to_mode(b, storage_class, NULL, NULL)));
411    ret->base_type = vtn_base_type_pointer;
412    ret->storage_class = storage_class;
413    ret->deref = t;
414    return ret;
415 }
416 
417 static struct vtn_type *
get_signed_type(struct vtn_builder * b,struct vtn_type * t)418 get_signed_type(struct vtn_builder *b, struct vtn_type *t)
419 {
420    if (t->base_type == vtn_base_type_pointer) {
421       return get_pointer_type(b, get_signed_type(b, t->deref), t->storage_class);
422    }
423    return get_vtn_type_for_glsl_type(
424       b, glsl_vector_type(glsl_signed_base_type_of(glsl_get_base_type(t->type)),
425                           glsl_get_vector_elements(t->type)));
426 }
427 
428 static nir_ssa_def *
handle_clc_fn(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode,int num_srcs,nir_ssa_def ** srcs,struct vtn_type ** src_types,const struct vtn_type * dest_type)429 handle_clc_fn(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
430               int num_srcs,
431               nir_ssa_def **srcs,
432               struct vtn_type **src_types,
433               const struct vtn_type *dest_type)
434 {
435    const char *name = remap_clc_opcode(opcode);
436    if (!name)
437        return NULL;
438 
439    /* Some functions which take params end up with uint (or pointer-to-uint) being passed,
440     * which doesn't mangle correctly when the function expects int or pointer-to-int.
441     * See https://www.khronos.org/registry/spir-v/specs/unified1/SPIRV.html#_a_id_unsignedsigned_a_unsigned_versus_signed_integers
442     */
443    int signed_param = -1;
444    switch (opcode) {
445    case OpenCLstd_Frexp:
446    case OpenCLstd_Lgamma_r:
447    case OpenCLstd_Pown:
448    case OpenCLstd_Rootn:
449    case OpenCLstd_Ldexp:
450       signed_param = 1;
451       break;
452    case OpenCLstd_Remquo:
453       signed_param = 2;
454       break;
455    case OpenCLstd_SMad_sat: {
456       /* All parameters need to be converted to signed */
457       src_types[0] = src_types[1] = src_types[2] = get_signed_type(b, src_types[0]);
458       break;
459    }
460    default: break;
461    }
462 
463    if (signed_param >= 0) {
464       src_types[signed_param] = get_signed_type(b, src_types[signed_param]);
465    }
466 
467    nir_deref_instr *ret_deref = NULL;
468 
469    if (!call_mangled_function(b, name, 0, num_srcs, src_types,
470                               dest_type, srcs, &ret_deref))
471       return NULL;
472 
473    return ret_deref ? nir_load_deref(&b->nb, ret_deref) : NULL;
474 }
475 
476 static nir_ssa_def *
handle_special(struct vtn_builder * b,uint32_t opcode,unsigned num_srcs,nir_ssa_def ** srcs,struct vtn_type ** src_types,const struct vtn_type * dest_type)477 handle_special(struct vtn_builder *b, uint32_t opcode,
478                unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
479                const struct vtn_type *dest_type)
480 {
481    nir_builder *nb = &b->nb;
482    enum OpenCLstd_Entrypoints cl_opcode = (enum OpenCLstd_Entrypoints)opcode;
483 
484    switch (cl_opcode) {
485    case OpenCLstd_SAbs_diff:
486      /* these works easier in direct NIR */
487       return nir_iabs_diff(nb, srcs[0], srcs[1]);
488    case OpenCLstd_UAbs_diff:
489       return nir_uabs_diff(nb, srcs[0], srcs[1]);
490    case OpenCLstd_Bitselect:
491       return nir_bitselect(nb, srcs[0], srcs[1], srcs[2]);
492    case OpenCLstd_SMad_hi:
493       return nir_imad_hi(nb, srcs[0], srcs[1], srcs[2]);
494    case OpenCLstd_UMad_hi:
495       return nir_umad_hi(nb, srcs[0], srcs[1], srcs[2]);
496    case OpenCLstd_SMul24:
497       return nir_imul24(nb, srcs[0], srcs[1]);
498    case OpenCLstd_UMul24:
499       return nir_umul24(nb, srcs[0], srcs[1]);
500    case OpenCLstd_SMad24:
501       return nir_imad24(nb, srcs[0], srcs[1], srcs[2]);
502    case OpenCLstd_UMad24:
503       return nir_umad24(nb, srcs[0], srcs[1], srcs[2]);
504    case OpenCLstd_FClamp:
505       return nir_fclamp(nb, srcs[0], srcs[1], srcs[2]);
506    case OpenCLstd_SClamp:
507       return nir_iclamp(nb, srcs[0], srcs[1], srcs[2]);
508    case OpenCLstd_UClamp:
509       return nir_uclamp(nb, srcs[0], srcs[1], srcs[2]);
510    case OpenCLstd_Copysign:
511       return nir_copysign(nb, srcs[0], srcs[1]);
512    case OpenCLstd_Cross:
513       if (dest_type->length == 4)
514          return nir_cross4(nb, srcs[0], srcs[1]);
515       return nir_cross3(nb, srcs[0], srcs[1]);
516    case OpenCLstd_Fdim:
517       return nir_fdim(nb, srcs[0], srcs[1]);
518    case OpenCLstd_Fmod:
519       if (nb->shader->options->lower_fmod)
520          break;
521       return nir_fmod(nb, srcs[0], srcs[1]);
522    case OpenCLstd_Mad:
523       return nir_fmad(nb, srcs[0], srcs[1], srcs[2]);
524    case OpenCLstd_Maxmag:
525       return nir_maxmag(nb, srcs[0], srcs[1]);
526    case OpenCLstd_Minmag:
527       return nir_minmag(nb, srcs[0], srcs[1]);
528    case OpenCLstd_Nan:
529       return nir_nan(nb, srcs[0]);
530    case OpenCLstd_Nextafter:
531       return nir_nextafter(nb, srcs[0], srcs[1]);
532    case OpenCLstd_Normalize:
533       return nir_normalize(nb, srcs[0]);
534    case OpenCLstd_Clz:
535       return nir_clz_u(nb, srcs[0]);
536    case OpenCLstd_Ctz:
537       return nir_ctz_u(nb, srcs[0]);
538    case OpenCLstd_Select:
539       return nir_select(nb, srcs[0], srcs[1], srcs[2]);
540    case OpenCLstd_S_Upsample:
541    case OpenCLstd_U_Upsample:
542       /* SPIR-V and CL have different defs for upsample, just implement in nir */
543       return nir_upsample(nb, srcs[0], srcs[1]);
544    case OpenCLstd_Native_exp:
545       return nir_fexp(nb, srcs[0]);
546    case OpenCLstd_Native_exp10:
547       return nir_fexp2(nb, nir_fmul_imm(nb, srcs[0], log(10) / log(2)));
548    case OpenCLstd_Native_log:
549       return nir_flog(nb, srcs[0]);
550    case OpenCLstd_Native_log10:
551       return nir_fmul_imm(nb, nir_flog2(nb, srcs[0]), log(2) / log(10));
552    case OpenCLstd_Native_tan:
553       return nir_ftan(nb, srcs[0]);
554    case OpenCLstd_Ldexp:
555       if (nb->shader->options->lower_ldexp)
556          break;
557       return nir_ldexp(nb, srcs[0], srcs[1]);
558    case OpenCLstd_Fma:
559       /* FIXME: the software implementation only supports fp32 for now. */
560       if (nb->shader->options->lower_ffma32 && srcs[0]->bit_size == 32)
561          break;
562       return nir_ffma(nb, srcs[0], srcs[1], srcs[2]);
563    default:
564       break;
565    }
566 
567    nir_ssa_def *ret = handle_clc_fn(b, opcode, num_srcs, srcs, src_types, dest_type);
568    if (!ret)
569       vtn_fail("No NIR equivalent");
570 
571    return ret;
572 }
573 
574 static nir_ssa_def *
handle_core(struct vtn_builder * b,uint32_t opcode,unsigned num_srcs,nir_ssa_def ** srcs,struct vtn_type ** src_types,const struct vtn_type * dest_type)575 handle_core(struct vtn_builder *b, uint32_t opcode,
576             unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
577             const struct vtn_type *dest_type)
578 {
579    nir_deref_instr *ret_deref = NULL;
580 
581    switch ((SpvOp)opcode) {
582    case SpvOpGroupAsyncCopy: {
583       /* Libclc doesn't include 3-component overloads of the async copy functions.
584        * However, the CLC spec says:
585        * async_work_group_copy and async_work_group_strided_copy for 3-component vector types
586        * behave as async_work_group_copy and async_work_group_strided_copy respectively for 4-component
587        * vector types
588        */
589       for (unsigned i = 0; i < num_srcs; ++i) {
590          if (src_types[i]->base_type == vtn_base_type_pointer &&
591              src_types[i]->deref->base_type == vtn_base_type_vector &&
592              src_types[i]->deref->length == 3) {
593             src_types[i] =
594                get_pointer_type(b,
595                                 get_vtn_type_for_glsl_type(b, glsl_replace_vector_type(src_types[i]->deref->type, 4)),
596                                 src_types[i]->storage_class);
597          }
598       }
599       if (!call_mangled_function(b, "async_work_group_strided_copy", (1 << 1), num_srcs, src_types, dest_type, srcs, &ret_deref))
600          return NULL;
601       break;
602    }
603    case SpvOpGroupWaitEvents: {
604       src_types[0] = get_vtn_type_for_glsl_type(b, glsl_int_type());
605       if (!call_mangled_function(b, "wait_group_events", 0, num_srcs, src_types, dest_type, srcs, &ret_deref))
606          return NULL;
607       break;
608    }
609    default:
610       return NULL;
611    }
612 
613    return ret_deref ? nir_load_deref(&b->nb, ret_deref) : NULL;
614 }
615 
616 
617 static void
_handle_v_load_store(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode,const uint32_t * w,unsigned count,bool load,bool vec_aligned,nir_rounding_mode rounding)618 _handle_v_load_store(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
619                      const uint32_t *w, unsigned count, bool load,
620                      bool vec_aligned, nir_rounding_mode rounding)
621 {
622    struct vtn_type *type;
623    if (load)
624       type = vtn_get_type(b, w[1]);
625    else
626       type = vtn_get_value_type(b, w[5]);
627    unsigned a = load ? 0 : 1;
628 
629    enum glsl_base_type base_type = glsl_get_base_type(type->type);
630    unsigned components = glsl_get_vector_elements(type->type);
631 
632    nir_ssa_def *offset = vtn_get_nir_ssa(b, w[5 + a]);
633    struct vtn_value *p = vtn_value(b, w[6 + a], vtn_value_type_pointer);
634 
635    enum glsl_base_type ptr_base_type =
636       glsl_get_base_type(p->pointer->type->type);
637    if (base_type != ptr_base_type) {
638       vtn_fail_if(ptr_base_type != GLSL_TYPE_FLOAT16 ||
639                   (base_type != GLSL_TYPE_FLOAT &&
640                    base_type != GLSL_TYPE_DOUBLE),
641                   "vload/vstore cannot do type conversion. "
642                   "vload/vstore_half can only convert from half to other "
643                   "floating-point types.");
644    }
645 
646    struct vtn_ssa_value *comps[NIR_MAX_VEC_COMPONENTS];
647    nir_ssa_def *ncomps[NIR_MAX_VEC_COMPONENTS];
648 
649    nir_ssa_def *moffset = nir_imul_imm(&b->nb, offset,
650       (vec_aligned && components == 3) ? 4 : components);
651    nir_deref_instr *deref = vtn_pointer_to_deref(b, p->pointer);
652 
653    unsigned alignment = vec_aligned ? glsl_get_cl_alignment(type->type) :
654                                       glsl_get_bit_size(type->type) / 8;
655    deref = nir_alignment_deref_cast(&b->nb, deref, alignment, 0);
656 
657    for (int i = 0; i < components; i++) {
658       nir_ssa_def *coffset = nir_iadd_imm(&b->nb, moffset, i);
659       nir_deref_instr *arr_deref = nir_build_deref_ptr_as_array(&b->nb, deref, coffset);
660 
661       if (load) {
662          comps[i] = vtn_local_load(b, arr_deref, p->type->access);
663          ncomps[i] = comps[i]->def;
664          if (base_type != ptr_base_type) {
665             assert(ptr_base_type == GLSL_TYPE_FLOAT16 &&
666                    (base_type == GLSL_TYPE_FLOAT ||
667                     base_type == GLSL_TYPE_DOUBLE));
668             ncomps[i] = nir_f2fN(&b->nb, ncomps[i],
669                                  glsl_base_type_get_bit_size(base_type));
670          }
671       } else {
672          struct vtn_ssa_value *ssa = vtn_create_ssa_value(b, glsl_scalar_type(base_type));
673          struct vtn_ssa_value *val = vtn_ssa_value(b, w[5]);
674          ssa->def = nir_channel(&b->nb, val->def, i);
675          if (base_type != ptr_base_type) {
676             assert(ptr_base_type == GLSL_TYPE_FLOAT16 &&
677                    (base_type == GLSL_TYPE_FLOAT ||
678                     base_type == GLSL_TYPE_DOUBLE));
679             if (rounding == nir_rounding_mode_undef) {
680                ssa->def = nir_f2f16(&b->nb, ssa->def);
681             } else {
682                ssa->def = nir_convert_alu_types(&b->nb, ssa->def,
683                                                 nir_type_float,
684                                                 nir_type_float16,
685                                                 rounding, false);
686             }
687          }
688          vtn_local_store(b, ssa, arr_deref, p->type->access);
689       }
690    }
691    if (load) {
692       vtn_push_nir_ssa(b, w[2], nir_vec(&b->nb, ncomps, components));
693    }
694 }
695 
696 static void
vtn_handle_opencl_vload(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode,const uint32_t * w,unsigned count)697 vtn_handle_opencl_vload(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
698                         const uint32_t *w, unsigned count)
699 {
700    _handle_v_load_store(b, opcode, w, count, true,
701                         opcode == OpenCLstd_Vloada_halfn,
702                         nir_rounding_mode_undef);
703 }
704 
705 static void
vtn_handle_opencl_vstore(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode,const uint32_t * w,unsigned count)706 vtn_handle_opencl_vstore(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
707                          const uint32_t *w, unsigned count)
708 {
709    _handle_v_load_store(b, opcode, w, count, false,
710                         opcode == OpenCLstd_Vstorea_halfn,
711                         nir_rounding_mode_undef);
712 }
713 
714 static void
vtn_handle_opencl_vstore_half_r(struct vtn_builder * b,enum OpenCLstd_Entrypoints opcode,const uint32_t * w,unsigned count)715 vtn_handle_opencl_vstore_half_r(struct vtn_builder *b, enum OpenCLstd_Entrypoints opcode,
716                                 const uint32_t *w, unsigned count)
717 {
718    _handle_v_load_store(b, opcode, w, count, false,
719                         opcode == OpenCLstd_Vstorea_halfn_r,
720                         vtn_rounding_mode_to_nir(b, w[8]));
721 }
722 
723 static nir_ssa_def *
handle_printf(struct vtn_builder * b,uint32_t opcode,unsigned num_srcs,nir_ssa_def ** srcs,struct vtn_type ** src_types,const struct vtn_type * dest_type)724 handle_printf(struct vtn_builder *b, uint32_t opcode,
725               unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
726               const struct vtn_type *dest_type)
727 {
728    /* hahah, yeah, right.. */
729    return nir_imm_int(&b->nb, -1);
730 }
731 
732 static nir_ssa_def *
handle_round(struct vtn_builder * b,uint32_t opcode,unsigned num_srcs,nir_ssa_def ** srcs,struct vtn_type ** src_types,const struct vtn_type * dest_type)733 handle_round(struct vtn_builder *b, uint32_t opcode,
734              unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
735              const struct vtn_type *dest_type)
736 {
737    nir_ssa_def *src = srcs[0];
738    nir_builder *nb = &b->nb;
739    nir_ssa_def *half = nir_imm_floatN_t(nb, 0.5, src->bit_size);
740    nir_ssa_def *truncated = nir_ftrunc(nb, src);
741    nir_ssa_def *remainder = nir_fsub(nb, src, truncated);
742 
743    return nir_bcsel(nb, nir_fge(nb, nir_fabs(nb, remainder), half),
744                     nir_fadd(nb, truncated, nir_fsign(nb, src)), truncated);
745 }
746 
747 static nir_ssa_def *
handle_shuffle(struct vtn_builder * b,uint32_t opcode,unsigned num_srcs,nir_ssa_def ** srcs,struct vtn_type ** src_types,const struct vtn_type * dest_type)748 handle_shuffle(struct vtn_builder *b, uint32_t opcode,
749                unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
750                const struct vtn_type *dest_type)
751 {
752    struct nir_ssa_def *input = srcs[0];
753    struct nir_ssa_def *mask = srcs[1];
754 
755    unsigned out_elems = dest_type->length;
756    nir_ssa_def *outres[NIR_MAX_VEC_COMPONENTS];
757    unsigned in_elems = input->num_components;
758    if (mask->bit_size != 32)
759       mask = nir_u2u32(&b->nb, mask);
760    mask = nir_iand(&b->nb, mask, nir_imm_intN_t(&b->nb, in_elems - 1, mask->bit_size));
761    for (unsigned i = 0; i < out_elems; i++)
762       outres[i] = nir_vector_extract(&b->nb, input, nir_channel(&b->nb, mask, i));
763 
764    return nir_vec(&b->nb, outres, out_elems);
765 }
766 
767 static nir_ssa_def *
handle_shuffle2(struct vtn_builder * b,uint32_t opcode,unsigned num_srcs,nir_ssa_def ** srcs,struct vtn_type ** src_types,const struct vtn_type * dest_type)768 handle_shuffle2(struct vtn_builder *b, uint32_t opcode,
769                 unsigned num_srcs, nir_ssa_def **srcs, struct vtn_type **src_types,
770                 const struct vtn_type *dest_type)
771 {
772    struct nir_ssa_def *input0 = srcs[0];
773    struct nir_ssa_def *input1 = srcs[1];
774    struct nir_ssa_def *mask = srcs[2];
775 
776    unsigned out_elems = dest_type->length;
777    nir_ssa_def *outres[NIR_MAX_VEC_COMPONENTS];
778    unsigned in_elems = input0->num_components;
779    unsigned total_mask = 2 * in_elems - 1;
780    unsigned half_mask = in_elems - 1;
781    if (mask->bit_size != 32)
782       mask = nir_u2u32(&b->nb, mask);
783    mask = nir_iand(&b->nb, mask, nir_imm_intN_t(&b->nb, total_mask, mask->bit_size));
784    for (unsigned i = 0; i < out_elems; i++) {
785       nir_ssa_def *this_mask = nir_channel(&b->nb, mask, i);
786       nir_ssa_def *vmask = nir_iand(&b->nb, this_mask, nir_imm_intN_t(&b->nb, half_mask, mask->bit_size));
787       nir_ssa_def *val0 = nir_vector_extract(&b->nb, input0, vmask);
788       nir_ssa_def *val1 = nir_vector_extract(&b->nb, input1, vmask);
789       nir_ssa_def *sel = nir_ilt(&b->nb, this_mask, nir_imm_intN_t(&b->nb, in_elems, mask->bit_size));
790       outres[i] = nir_bcsel(&b->nb, sel, val0, val1);
791    }
792    return nir_vec(&b->nb, outres, out_elems);
793 }
794 
/* Dispatch a single OpenCL.std extended instruction to the matching lowering
 * helper.  Returns true when the opcode was consumed; genuinely unknown
 * opcodes hit vtn_fail().  For the handle_instr() cases, operands start at
 * w[5] and w[1] points at the result type/id words (the extended-instruction
 * header occupies w[0..4]).
 */
bool
vtn_handle_opencl_instruction(struct vtn_builder *b, SpvOp ext_opcode,
                              const uint32_t *w, unsigned count)
{
   enum OpenCLstd_Entrypoints cl_opcode = (enum OpenCLstd_Entrypoints) ext_opcode;

   switch (cl_opcode) {
   /* Ops with a direct NIR ALU equivalent: lowered by handle_alu. */
   case OpenCLstd_Fabs:
   case OpenCLstd_SAbs:
   case OpenCLstd_UAbs:
   case OpenCLstd_SAdd_sat:
   case OpenCLstd_UAdd_sat:
   case OpenCLstd_Ceil:
   case OpenCLstd_Floor:
   case OpenCLstd_Fmax:
   case OpenCLstd_SHadd:
   case OpenCLstd_UHadd:
   case OpenCLstd_SMax:
   case OpenCLstd_UMax:
   case OpenCLstd_Fmin:
   case OpenCLstd_SMin:
   case OpenCLstd_UMin:
   case OpenCLstd_Mix:
   case OpenCLstd_Native_cos:
   case OpenCLstd_Native_divide:
   case OpenCLstd_Native_exp2:
   case OpenCLstd_Native_log2:
   case OpenCLstd_Native_powr:
   case OpenCLstd_Native_recip:
   case OpenCLstd_Native_rsqrt:
   case OpenCLstd_Native_sin:
   case OpenCLstd_Native_sqrt:
   case OpenCLstd_SMul_hi:
   case OpenCLstd_UMul_hi:
   case OpenCLstd_Popcount:
   case OpenCLstd_SRhadd:
   case OpenCLstd_URhadd:
   case OpenCLstd_Rsqrt:
   case OpenCLstd_Sign:
   case OpenCLstd_Sqrt:
   case OpenCLstd_SSub_sat:
   case OpenCLstd_USub_sat:
   case OpenCLstd_Trunc:
   case OpenCLstd_Rint:
   case OpenCLstd_Half_divide:
   case OpenCLstd_Half_recip:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_alu);
      return true;
   /* Ops built from multiple NIR instructions (or nir_builtin_builder
    * helpers): lowered by handle_special.
    */
   case OpenCLstd_SAbs_diff:
   case OpenCLstd_UAbs_diff:
   case OpenCLstd_SMad_hi:
   case OpenCLstd_UMad_hi:
   case OpenCLstd_SMad24:
   case OpenCLstd_UMad24:
   case OpenCLstd_SMul24:
   case OpenCLstd_UMul24:
   case OpenCLstd_Bitselect:
   case OpenCLstd_FClamp:
   case OpenCLstd_SClamp:
   case OpenCLstd_UClamp:
   case OpenCLstd_Copysign:
   case OpenCLstd_Cross:
   case OpenCLstd_Degrees:
   case OpenCLstd_Fdim:
   case OpenCLstd_Fma:
   case OpenCLstd_Distance:
   case OpenCLstd_Fast_distance:
   case OpenCLstd_Fast_length:
   case OpenCLstd_Fast_normalize:
   case OpenCLstd_Half_rsqrt:
   case OpenCLstd_Half_sqrt:
   case OpenCLstd_Length:
   case OpenCLstd_Mad:
   case OpenCLstd_Maxmag:
   case OpenCLstd_Minmag:
   case OpenCLstd_Nan:
   case OpenCLstd_Nextafter:
   case OpenCLstd_Normalize:
   case OpenCLstd_Radians:
   case OpenCLstd_Rotate:
   case OpenCLstd_Select:
   case OpenCLstd_Step:
   case OpenCLstd_Smoothstep:
   case OpenCLstd_S_Upsample:
   case OpenCLstd_U_Upsample:
   case OpenCLstd_Clz:
   case OpenCLstd_Ctz:
   case OpenCLstd_Native_exp:
   case OpenCLstd_Native_exp10:
   case OpenCLstd_Native_log:
   case OpenCLstd_Native_log10:
   case OpenCLstd_Acos:
   case OpenCLstd_Acosh:
   case OpenCLstd_Acospi:
   case OpenCLstd_Asin:
   case OpenCLstd_Asinh:
   case OpenCLstd_Asinpi:
   case OpenCLstd_Atan:
   case OpenCLstd_Atan2:
   case OpenCLstd_Atanh:
   case OpenCLstd_Atanpi:
   case OpenCLstd_Atan2pi:
   case OpenCLstd_Fract:
   case OpenCLstd_Frexp:
   case OpenCLstd_Exp:
   case OpenCLstd_Exp2:
   case OpenCLstd_Expm1:
   case OpenCLstd_Exp10:
   case OpenCLstd_Fmod:
   case OpenCLstd_Ilogb:
   case OpenCLstd_Log:
   case OpenCLstd_Log2:
   case OpenCLstd_Log10:
   case OpenCLstd_Log1p:
   case OpenCLstd_Logb:
   case OpenCLstd_Ldexp:
   case OpenCLstd_Cos:
   case OpenCLstd_Cosh:
   case OpenCLstd_Cospi:
   case OpenCLstd_Sin:
   case OpenCLstd_Sinh:
   case OpenCLstd_Sinpi:
   case OpenCLstd_Tan:
   case OpenCLstd_Tanh:
   case OpenCLstd_Tanpi:
   case OpenCLstd_Cbrt:
   case OpenCLstd_Erfc:
   case OpenCLstd_Erf:
   case OpenCLstd_Lgamma:
   case OpenCLstd_Lgamma_r:
   case OpenCLstd_Tgamma:
   case OpenCLstd_Pow:
   case OpenCLstd_Powr:
   case OpenCLstd_Pown:
   case OpenCLstd_Rootn:
   case OpenCLstd_Remainder:
   case OpenCLstd_Remquo:
   case OpenCLstd_Hypot:
   case OpenCLstd_Sincos:
   case OpenCLstd_Modf:
   case OpenCLstd_UMad_sat:
   case OpenCLstd_SMad_sat:
   case OpenCLstd_Native_tan:
   case OpenCLstd_Half_cos:
   case OpenCLstd_Half_exp:
   case OpenCLstd_Half_exp2:
   case OpenCLstd_Half_exp10:
   case OpenCLstd_Half_log:
   case OpenCLstd_Half_log2:
   case OpenCLstd_Half_log10:
   case OpenCLstd_Half_powr:
   case OpenCLstd_Half_sin:
   case OpenCLstd_Half_tan:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_special);
      return true;
   /* Vector loads/stores get their own word layouts, so they bypass
    * handle_instr and parse the instruction words themselves.
    */
   case OpenCLstd_Vloadn:
   case OpenCLstd_Vload_half:
   case OpenCLstd_Vload_halfn:
   case OpenCLstd_Vloada_halfn:
      vtn_handle_opencl_vload(b, cl_opcode, w, count);
      return true;
   case OpenCLstd_Vstoren:
   case OpenCLstd_Vstore_half:
   case OpenCLstd_Vstore_halfn:
   case OpenCLstd_Vstorea_halfn:
      vtn_handle_opencl_vstore(b, cl_opcode, w, count);
      return true;
   /* Half stores with an explicit rounding-mode operand. */
   case OpenCLstd_Vstore_half_r:
   case OpenCLstd_Vstore_halfn_r:
   case OpenCLstd_Vstorea_halfn_r:
      vtn_handle_opencl_vstore_half_r(b, cl_opcode, w, count);
      return true;
   case OpenCLstd_Shuffle:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_shuffle);
      return true;
   case OpenCLstd_Shuffle2:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_shuffle2);
      return true;
   case OpenCLstd_Round:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_round);
      return true;
   /* Not really implemented: printf folds to -1 (failure). */
   case OpenCLstd_Printf:
      handle_instr(b, ext_opcode, w + 5, count - 5, w + 1, handle_printf);
      return true;
   case OpenCLstd_Prefetch:
      /* TODO maybe add a nir instruction for this? */
      return true;
   default:
      vtn_fail("unhandled opencl opc: %u\n", ext_opcode);
      return false;
   }
}
987 
988 bool
vtn_handle_opencl_core_instruction(struct vtn_builder * b,SpvOp opcode,const uint32_t * w,unsigned count)989 vtn_handle_opencl_core_instruction(struct vtn_builder *b, SpvOp opcode,
990                                    const uint32_t *w, unsigned count)
991 {
992    switch (opcode) {
993    case SpvOpGroupAsyncCopy:
994       handle_instr(b, opcode, w + 4, count - 4, w + 1, handle_core);
995       return true;
996    case SpvOpGroupWaitEvents:
997       handle_instr(b, opcode, w + 2, count - 2, NULL, handle_core);
998       return true;
999    default:
1000       return false;
1001    }
1002    return true;
1003 }
1004