/*
 * Copyright 2023 Intel Corporation
 * SPDX-License-Identifier: MIT
 */

6 #include "brw_fs.h"
7 #include "brw_fs_builder.h"
8 
9 using namespace brw;
10 
11 static void
f16_using_mac(const fs_builder & bld,fs_inst * inst)12 f16_using_mac(const fs_builder &bld, fs_inst *inst)
13 {
14    /* We only intend to support configurations where the destination and
15     * accumulator have the same type.
16     */
17    if (!inst->src[0].is_null())
18       assert(inst->dst.type == inst->src[0].type);
19 
20    assert(inst->src[1].type == BRW_REGISTER_TYPE_HF);
21    assert(inst->src[2].type == BRW_REGISTER_TYPE_HF);
22 
23    const brw_reg_type src0_type = inst->dst.type;
24    const brw_reg_type src1_type = BRW_REGISTER_TYPE_HF;
25    const brw_reg_type src2_type = BRW_REGISTER_TYPE_HF;
26 
27    const fs_reg dest = inst->dst;
28    fs_reg src0 = inst->src[0];
29    const fs_reg src1 = retype(inst->src[1], src1_type);
30    const fs_reg src2 = retype(inst->src[2], src2_type);
31 
32    const unsigned dest_stride =
33       dest.type == BRW_REGISTER_TYPE_HF ? REG_SIZE / 2 : REG_SIZE;
34 
35    for (unsigned r = 0; r < inst->rcount; r++) {
36       fs_reg temp = bld.vgrf(BRW_REGISTER_TYPE_HF, 1);
37 
38       for (unsigned subword = 0; subword < 2; subword++) {
39          for (unsigned s = 0; s < inst->sdepth; s++) {
40             /* The first multiply of the dot-product operation has to
41              * explicitly write the accumulator register. The successive MAC
42              * instructions will implicitly read *and* write the
43              * accumulator. Those MAC instructions can also optionally
44              * explicitly write some other register.
45              *
46              * FINISHME: The accumulator can actually hold 16 HF values. On
47              * Gfx12 there are two accumulators. It should be possible to do
48              * this in SIMD16 or even SIMD32. I was unable to get this to work
49              * properly.
50              */
51             if (s == 0 && subword == 0) {
52                const unsigned acc_width = 8;
53                fs_reg acc = suboffset(retype(brw_acc_reg(inst->exec_size), BRW_REGISTER_TYPE_UD),
54                                       inst->group % acc_width);
55 
56                if (bld.shader->devinfo->verx10 >= 125) {
57                   acc = subscript(acc, BRW_REGISTER_TYPE_HF, subword);
58                } else {
59                   acc = retype(acc, BRW_REGISTER_TYPE_HF);
60                }
61 
62                bld.MUL(acc,
63                        subscript(retype(byte_offset(src1, s * REG_SIZE),
64                                         BRW_REGISTER_TYPE_UD),
65                                  BRW_REGISTER_TYPE_HF, subword),
66                        component(retype(byte_offset(src2, r * REG_SIZE),
67                                         BRW_REGISTER_TYPE_HF),
68                                  s * 2 + subword))
69                   ->writes_accumulator = true;
70 
71             } else {
72                fs_reg result;
73 
74                /* As mentioned above, the MAC had an optional, explicit
75                 * destination register. Various optimization passes are not
76                 * clever enough to understand the intricacies of this
77                 * instruction, so only write the result register on the final
78                 * MAC in the sequence.
79                 */
80                if ((s + 1) == inst->sdepth && subword == 1)
81                   result = temp;
82                else
83                   result = retype(bld.null_reg_ud(), BRW_REGISTER_TYPE_HF);
84 
85                bld.MAC(result,
86                        subscript(retype(byte_offset(src1, s * REG_SIZE),
87                                         BRW_REGISTER_TYPE_UD),
88                                  BRW_REGISTER_TYPE_HF, subword),
89                        component(retype(byte_offset(src2, r * REG_SIZE),
90                                         BRW_REGISTER_TYPE_HF),
91                                  s * 2 + subword))
92                   ->writes_accumulator = true;
93             }
94          }
95       }
96 
97       if (!src0.is_null()) {
98          if (src0_type != BRW_REGISTER_TYPE_HF) {
99             fs_reg temp2 = bld.vgrf(src0_type, 1);
100 
101             bld.MOV(temp2, temp);
102 
103             bld.ADD(byte_offset(dest, r * dest_stride),
104                     temp2,
105                     byte_offset(src0, r * dest_stride));
106          } else {
107             bld.ADD(byte_offset(dest, r * dest_stride),
108                     temp,
109                     byte_offset(src0, r * dest_stride));
110          }
111       } else {
112          bld.MOV(byte_offset(dest, r * dest_stride), temp);
113       }
114    }
115 }
116 
117 static void
int8_using_dp4a(const fs_builder & bld,fs_inst * inst)118 int8_using_dp4a(const fs_builder &bld, fs_inst *inst)
119 {
120    /* We only intend to support configurations where the destination and
121     * accumulator have the same type.
122     */
123    if (!inst->src[0].is_null())
124       assert(inst->dst.type == inst->src[0].type);
125 
126    assert(inst->src[1].type == BRW_REGISTER_TYPE_B ||
127           inst->src[1].type == BRW_REGISTER_TYPE_UB);
128    assert(inst->src[2].type == BRW_REGISTER_TYPE_B ||
129           inst->src[2].type == BRW_REGISTER_TYPE_UB);
130 
131    const brw_reg_type src1_type = inst->src[1].type == BRW_REGISTER_TYPE_UB
132       ? BRW_REGISTER_TYPE_UD : BRW_REGISTER_TYPE_D;
133 
134    const brw_reg_type src2_type = inst->src[2].type == BRW_REGISTER_TYPE_UB
135       ? BRW_REGISTER_TYPE_UD : BRW_REGISTER_TYPE_D;
136 
137    fs_reg dest = inst->dst;
138    fs_reg src0 = inst->src[0];
139    const fs_reg src1 = retype(inst->src[1], src1_type);
140    const fs_reg src2 = retype(inst->src[2], src2_type);
141 
142    const unsigned dest_stride = REG_SIZE;
143 
144    for (unsigned r = 0; r < inst->rcount; r++) {
145       if (!src0.is_null()) {
146          bld.MOV(dest, src0);
147          src0 = byte_offset(src0, dest_stride);
148       } else {
149          bld.MOV(dest, retype(brw_imm_d(0), dest.type));
150       }
151 
152       for (unsigned s = 0; s < inst->sdepth; s++) {
153          bld.DP4A(dest,
154                   dest,
155                   byte_offset(src1, s * REG_SIZE),
156                   component(byte_offset(src2, r * REG_SIZE), s))
157             ->saturate = inst->saturate;
158       }
159 
160       dest = byte_offset(dest, dest_stride);
161    }
162 }
163 
164 static void
int8_using_mul_add(const fs_builder & bld,fs_inst * inst)165 int8_using_mul_add(const fs_builder &bld, fs_inst *inst)
166 {
167    /* We only intend to support configurations where the destination and
168     * accumulator have the same type.
169     */
170    if (!inst->src[0].is_null())
171       assert(inst->dst.type == inst->src[0].type);
172 
173    assert(inst->src[1].type == BRW_REGISTER_TYPE_B ||
174           inst->src[1].type == BRW_REGISTER_TYPE_UB);
175    assert(inst->src[2].type == BRW_REGISTER_TYPE_B ||
176           inst->src[2].type == BRW_REGISTER_TYPE_UB);
177 
178    const brw_reg_type src0_type = inst->dst.type;
179 
180    const brw_reg_type src1_type = inst->src[1].type == BRW_REGISTER_TYPE_UB
181       ? BRW_REGISTER_TYPE_UD : BRW_REGISTER_TYPE_D;
182 
183    const brw_reg_type src2_type = inst->src[2].type == BRW_REGISTER_TYPE_UB
184       ? BRW_REGISTER_TYPE_UD : BRW_REGISTER_TYPE_D;
185 
186    fs_reg dest = inst->dst;
187    fs_reg src0 = inst->src[0];
188    const fs_reg src1 = retype(inst->src[1], src1_type);
189    const fs_reg src2 = retype(inst->src[2], src2_type);
190 
191    const unsigned dest_stride = REG_SIZE;
192 
193    for (unsigned r = 0; r < inst->rcount; r++) {
194       if (!src0.is_null()) {
195          bld.MOV(dest, src0);
196          src0 = byte_offset(src0, dest_stride);
197       } else {
198          bld.MOV(dest, retype(brw_imm_d(0), dest.type));
199       }
200 
201       for (unsigned s = 0; s < inst->sdepth; s++) {
202          fs_reg temp1 = bld.vgrf(BRW_REGISTER_TYPE_UD, 1);
203          fs_reg temp2 = bld.vgrf(BRW_REGISTER_TYPE_UD, 1);
204          fs_reg temp3 = bld.vgrf(BRW_REGISTER_TYPE_UD, 2);
205          const brw_reg_type temp_type =
206             (inst->src[1].type == BRW_REGISTER_TYPE_B ||
207              inst->src[2].type == BRW_REGISTER_TYPE_B)
208             ? BRW_REGISTER_TYPE_W : BRW_REGISTER_TYPE_UW;
209 
210          /* Expand 8 dwords of packed bytes into 16 dwords of packed
211           * words.
212           *
213           * FINISHME: Gfx9 should not need this work around. Gfx11
214           * may be able to use integer MAD. Both platforms may be
215           * able to use MAC.
216           */
217          bld.group(32, 0).MOV(retype(temp3, temp_type),
218                               retype(byte_offset(src2, r * REG_SIZE),
219                                      inst->src[2].type));
220 
221          bld.MUL(subscript(temp1, temp_type, 0),
222                  subscript(retype(byte_offset(src1, s * REG_SIZE),
223                                   BRW_REGISTER_TYPE_UD),
224                            inst->src[1].type, 0),
225                  subscript(component(retype(temp3,
226                                             BRW_REGISTER_TYPE_UD),
227                                      s * 2),
228                            temp_type, 0));
229 
230          bld.MUL(subscript(temp1, temp_type, 1),
231                  subscript(retype(byte_offset(src1, s * REG_SIZE),
232                                   BRW_REGISTER_TYPE_UD),
233                            inst->src[1].type, 1),
234                  subscript(component(retype(temp3,
235                                             BRW_REGISTER_TYPE_UD),
236                                      s * 2),
237                            temp_type, 1));
238 
239          bld.MUL(subscript(temp2, temp_type, 0),
240                  subscript(retype(byte_offset(src1, s * REG_SIZE),
241                                   BRW_REGISTER_TYPE_UD),
242                            inst->src[1].type, 2),
243                  subscript(component(retype(temp3,
244                                             BRW_REGISTER_TYPE_UD),
245                                      s * 2 + 1),
246                            temp_type, 0));
247 
248          bld.MUL(subscript(temp2, temp_type, 1),
249                  subscript(retype(byte_offset(src1, s * REG_SIZE),
250                                   BRW_REGISTER_TYPE_UD),
251                            inst->src[1].type, 3),
252                  subscript(component(retype(temp3,
253                                             BRW_REGISTER_TYPE_UD),
254                                      s * 2 + 1),
255                            temp_type, 1));
256 
257          bld.ADD(subscript(temp1, src0_type, 0),
258                  subscript(temp1, temp_type, 0),
259                  subscript(temp1, temp_type, 1));
260 
261          bld.ADD(subscript(temp2, src0_type, 0),
262                  subscript(temp2, temp_type, 0),
263                  subscript(temp2, temp_type, 1));
264 
265          bld.ADD(retype(temp1, src0_type),
266                  retype(temp1, src0_type),
267                  retype(temp2, src0_type));
268 
269          bld.ADD(dest, dest, retype(temp1, src0_type))
270             ->saturate = inst->saturate;
271       }
272 
273       dest = byte_offset(dest, dest_stride);
274    }
275 }
276 
277 bool
brw_lower_dpas(fs_visitor & v)278 brw_lower_dpas(fs_visitor &v)
279 {
280    bool progress = false;
281 
282    foreach_block_and_inst_safe(block, fs_inst, inst, v.cfg) {
283       if (inst->opcode != BRW_OPCODE_DPAS)
284          continue;
285 
286       const fs_builder bld = fs_builder(&v, block, inst).group(8, 0).exec_all();
287 
288       if (brw_reg_type_is_floating_point(inst->dst.type)) {
289          f16_using_mac(bld, inst);
290       } else {
291          if (v.devinfo->ver >= 12) {
292             int8_using_dp4a(bld, inst);
293          } else {
294             int8_using_mul_add(bld, inst);
295          }
296       }
297 
298       inst->remove(block);
299       progress = true;
300    }
301 
302    if (progress)
303       v.invalidate_analysis(DEPENDENCY_INSTRUCTIONS);
304 
305    return progress;
306 }
307