1 /*
2 * Copyright 2024 Valve Corporation
3 * SPDX-License-Identifier: MIT
4 */
5
6 #include "compiler/nir/nir.h"
7 #include "compiler/nir/nir_builder.h"
8 #include "agx_nir.h"
9
/*
 * AGX generally flushes FP32 denorms. However, the min/max instructions do
 * not, as they are implemented with cmpsel. We need to flush the results of
 * fp32 min/max for correctness. Doing so naively will generate redundant
 * flushes, so this pass tries to be clever and elide flushes when possible.
 *
 * This pass is still pretty simple: it doesn't see through phis or bcsels yet.
 */
18 static bool
could_be_denorm(nir_scalar s)19 could_be_denorm(nir_scalar s)
20 {
21 /* Constants can be denorms only if they are denorms. */
22 if (nir_scalar_is_const(s)) {
23 return fpclassify(nir_scalar_as_float(s)) == FP_SUBNORMAL;
24 }
25
26 /* Floating-point instructions flush denormals, so ALU results can only be
27 * denormal if they are not from a float instruction. Crucially fmin/fmax
28 * flushes in NIR, so this pass handles chains of fmin/fmax properly.
29 */
30 if (nir_scalar_is_alu(s)) {
31 nir_op op = nir_scalar_alu_op(s);
32 nir_alu_type T = nir_op_infos[op].output_type;
33
34 return nir_alu_type_get_base_type(T) != nir_type_float &&
35 op != nir_op_fmin_agx && op != nir_op_fmax_agx;
36 }
37
38 /* Otherwise, assume it could be denormal (say, loading from a buffer). */
39 return true;
40 }
41
42 static bool
lower(nir_builder * b,nir_alu_instr * alu,void * data)43 lower(nir_builder *b, nir_alu_instr *alu, void *data)
44 {
45 if ((alu->op != nir_op_fmin && alu->op != nir_op_fmax) ||
46 (alu->def.bit_size != 32))
47 return false;
48
49 /* Lower the op, we'll fix up the denorms right after. */
50 if (alu->op == nir_op_fmax)
51 alu->op = nir_op_fmax_agx;
52 else
53 alu->op = nir_op_fmin_agx;
54
55 /* We need to canonicalize the result if the output could be a denorm. That
56 * occurs only when one of the sources could be a denorm. Check each source.
57 * Swizzles don't affect denormalness so we can grab the def directly.
58 */
59 nir_scalar scalar = nir_get_scalar(&alu->def, 0);
60 nir_scalar src0 = nir_scalar_chase_alu_src(scalar, 0);
61 nir_scalar src1 = nir_scalar_chase_alu_src(scalar, 1);
62
63 if (could_be_denorm(src0) || could_be_denorm(src1)) {
64 b->cursor = nir_after_instr(&alu->instr);
65 nir_def *canonicalized = nir_fadd_imm(b, &alu->def, -0.0);
66 nir_def_rewrite_uses_after(&alu->def, canonicalized,
67 canonicalized->parent_instr);
68 }
69
70 return true;
71 }
72
73 bool
agx_nir_lower_fminmax(nir_shader * s)74 agx_nir_lower_fminmax(nir_shader *s)
75 {
76 return nir_shader_alu_pass(s, lower, nir_metadata_control_flow, NULL);
77 }
78