• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#version 450 core
10
11#define PRECISION ${PRECISION}
12
13#define op1(X) ${OPERATOR1}
14
15#define op2(X, Y) ${OPERATOR2}
16
17${define_active_storage_type(STORAGE)}
18
19#extension GL_EXT_control_flow_attributes : require
20
21layout(std430) buffer;
22
23${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
24${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}
25
26${layout_declare_ubo(B, "ivec3", "tout_limits")}
27${layout_declare_ubo(B, "ivec4", "tin_sizes")}
28
29layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
30
31layout(constant_id = 3) const int packed_dim = 0;
32layout(constant_id = 4) const int reduce_dim = 0;
33layout(constant_id = 5) const int group_dim = 1;
34
35// A more verbose name would be NWORKERS_PER_GROUP. This describes the number of
36// threads that will co-operate to compute one reduction output. There may be
37// multiple groups computing distinct reduction outputs within one work group.
38#define NWORKERS 4
39
40// Sets an upper limit on the total size of a work group based on how many
41// elements are allocated in the shared memory array below. Each thread in the
42// work group will write into its assigned element in the shared array.
43#define MAX_NTHREADS 16
44
45shared vec4 shared_vecs[MAX_NTHREADS];
46
47#include "indexing_utils.h"
48
// Converts a 2D thread id into this thread's slot in `shared_vecs`. `tid.x` is
// the worker index within a group and `tid.y` is the group index (see how
// `main` builds `tid`); each group owns NWORKERS consecutive slots.
int tid_to_smi(const ivec2 tid) {
  const int group_base = tid.y * NWORKERS;
  return group_base + tid.x;
}
52
53/*
54 * The shaders below compute softmax for a tensor. Softmax is an interesting mix
55 * between a reduction operator and a unary elementwise operator, defined as
56 * exp(x) / (sum of exp(x)). The general flow of the computation is:
57 *
58 * First, find the maximum element along the reduction dim. The maximum element
59 * is used to preserve numerical stability, since division of exponents is
60 * translation invariant.
61 *
62 * Next, compute the sum of exp(x - max_element) along the reduction dim.
63 *
64 * Finally, for each element along the reduction dim, we compute the output as
65 * exp(x - max_element) / sum_of_exponents.
66 *
67 * The shaders below also utilize shared memory to have multiple threads help
68 * compute the max and sum reduction operations. A total of NGROUPS x NWORKERS
69 * threads are launched. Each group works on a unique reduction "row", and
70 * within a group NWORKERS threads co-operate to compute the max and sum of one
71 * "row". Each worker in the group is responsible for computing a partial output
72 * of the "row" and uploading it to shared memory; the overall reduction output
73 * can then be determined by aggregating the partial outputs stored in shared
74 * memory.
75 *
76 * As a caveat, this shader does not currently support cases where `batch` > 1
77 * and the reduce dim happens to also be the batch concatenation dim.  To support
78 * this, there will need to be additional logic to set the starting value of
79 * `scan_pos[reduce_dim]`. Since this is not expected to be a common use-case,
80 * supporting this case is left as an exercise for when it is required.
81 *
 * As a final note, log softmax is supported with this shader as well, via the
 * op1 and op2 macro definitions. See the corresponding YAML file for more
 * details.
85 */
86
/*
 * Computes softmax where the reduction dim is orthogonal to the packed dim.
 * This case is simpler because each element of a texel belongs to a separate
 * reduction "row", meaning we don't have to perform reduction along a texel.
 *
 * tid.x is the worker index within the group and tid.y is the group index.
 * scan_pos is the texel position of the row being reduced; its reduce_dim
 * component is overwritten here as the row is scanned.
 *
 * The function runs three strided passes over the row (max, sum of exponents,
 * output). Worker `tid.x` visits positions tid.x, tid.x + NWORKERS, ... and the
 * partial results of the first two passes are combined through shared_vecs.
 */
void softmax_nonpacked_dim(const ivec2 tid, ivec3 scan_pos) {
  // shared memory index of this thread
  const int smi = tid_to_smi(tid);
  // used to iterate over all shared memory in the group
  int group_i;

  // Seed the running maximum with the first texel of the row. Index 0 always
  // belongs to the row, so it can never change the true maximum, whereas
  // seeding from `tid.x` would read past the end of the row whenever the
  // reduce dim is shorter than NWORKERS.
  scan_pos[reduce_dim] = 0;
  vec4 max_elements = load_texel(tin, scan_pos);

  // This thread computes a partial maximum
  scan_pos[reduce_dim] = tid.x;
  for (int i = tid.x; i < tin_sizes[reduce_dim];
       i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
    max_elements = max(max_elements, load_texel(tin, scan_pos));
  }
  shared_vecs[smi] = max_elements;
  barrier();
  // Iterate over the partial maximums to obtain the overall maximum
  group_i = tid.y * NWORKERS;
  max_elements = shared_vecs[group_i++];
  for (int i = 1; i < NWORKERS; ++i, group_i++) {
    max_elements = max(max_elements, shared_vecs[group_i]);
  }
  // Every worker must be done reading the partial maximums before any worker
  // reuses its shared memory slot for a partial sum below; without this
  // barrier a fast worker could overwrite a slot that a slower worker has not
  // read yet.
  barrier();

  scan_pos[reduce_dim] = tid.x;
  vec4 denominators = vec4(0);
  // Compute partial sum
  for (int i = tid.x; i < tin_sizes[reduce_dim];
       i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
    denominators += exp(load_texel(tin, scan_pos) - max_elements);
  }
  shared_vecs[smi] = denominators;
  barrier();
  // Iterate over the partial sums to obtain the overall sum
  group_i = tid.y * NWORKERS;
  denominators = shared_vecs[group_i++];
  for (int i = 1; i < NWORKERS; ++i, group_i++) {
    denominators += shared_vecs[group_i];
  }

  // Determine if there are any padding elements in the final texel of the
  // packed dimension
  const int nspill = mod4(tin_sizes[packed_dim]);
  // Detect if this thread is working on the final texels of the packed
  // dimension, which may have padding elements. The packed dim coordinate of
  // scan_pos never changes in this function, so this is loop invariant.
  const bool is_last_texel =
      scan_pos[packed_dim] == (tout_limits[packed_dim] - 1);
  const bool zero_padding = is_last_texel && nspill > 0;

  scan_pos[reduce_dim] = tid.x;
  for (int i = tid.x; i < tin_sizes[reduce_dim];
       i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
    const vec4 numerators = op1(load_texel(tin, scan_pos) - max_elements);
    vec4 outtex = op2(numerators, denominators);
    // For the last texel in the packed dim, make sure that the padding elements
    // are explicitly set to 0. Otherwise, they may influence computations later
    // down the line.
    if (zero_padding) {
      [[unroll]] for (int j = nspill; j < 4; ++j) {
        outtex[j] = 0;
      }
    }
    write_texel(tout, scan_pos, outtex);
  }
}
154
/*
 * Compute softmax where the reduction dim is also the packed dim. This case is
 * complex because the reduction needs to occur over the individual texels.
 * Therefore, in this algorithm each element of the accumulator texels is
 * itself a partial output. Special care has to be taken to ignore padding
 * elements in texels (which occur when the size of the packed dim is not a
 * multiple of 4) so that they do not influence the output of reduction.
 *
 * tid.x is the worker index within the group and tid.y is the group index.
 * scan_pos is the texel position of the row being reduced; its reduce_dim
 * component is overwritten here as the row is scanned.
 *
 * Each pass first walks the texels that are completely filled (the first
 * `reduce_len` elements) with a stride of NWORKERS texels, and then the single
 * worker whose scan position lands on the final, partially filled texel
 * handles that texel element by element.
 */
void softmax_packed_dim(const ivec2 tid, ivec3 scan_pos) {
  // shared memory index of this thread
  const int smi = tid_to_smi(tid);
  // used to iterate over all shared memory in the group
  int group_i;

  // Number of valid elements in the final texel (0 means it is full), and the
  // number of elements covered by completely filled texels.
  const int nspill = mod4(tin_sizes[packed_dim]);
  const int reduce_len = tin_sizes[packed_dim] - nspill;

  // Seed the running maximum with the first element of the row. Element 0 of
  // texel 0 is always a real (non-padding) element, so it can never change the
  // true maximum, whereas seeding from texel `tid.x` would read past the end of
  // the row whenever the row has fewer texels than NWORKERS.
  scan_pos[reduce_dim] = 0;
  vec4 max_elements = vec4(load_texel(tin, scan_pos).x);

  scan_pos[reduce_dim] = tid.x;
  for (int i = tid.x * 4; i < reduce_len;
       i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
    max_elements = max(max_elements, load_texel(tin, scan_pos));
  }
  // For the last texel in the dim, if there are padding elements then each
  // element of the texel needs to be processed individually such that the
  // padding elements are ignored
  if (nspill > 0 && scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1) {
    const vec4 intex = load_texel(tin, scan_pos);
    for (int i = 0; i < nspill; ++i) {
      max_elements.x = max(intex[i], max_elements.x);
    }
  }
  shared_vecs[smi] = max_elements;
  barrier();
  // Iterate over the partial maximums to obtain the overall maximum
  group_i = tid.y * NWORKERS;
  max_elements = shared_vecs[group_i++];
  for (int i = 1; i < NWORKERS; ++i, group_i++) {
    max_elements = max(max_elements, shared_vecs[group_i]);
  }
  // Every worker must be done reading the partial maximums before any worker
  // reuses its shared memory slot for a partial sum below; without this
  // barrier a fast worker could overwrite a slot that a slower worker has not
  // read yet.
  barrier();
  // Each element of the texel is itself a partial maximum; iterate over the
  // texel to find the actual maximum
  float max_element = max_elements.x;
  [[unroll]] for (int i = 1; i < 4; ++i) {
    max_element = max(max_elements[i], max_element);
  }

  scan_pos[reduce_dim] = tid.x;
  vec4 denominators = vec4(0);
  for (int i = tid.x * 4; i < reduce_len;
       i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
    denominators += exp(load_texel(tin, scan_pos) - max_element);
  }
  // For the last texel in the dim, if there are padding elements then each
  // element of the texel needs to be processed individually such that the
  // padding elements are ignored
  if (nspill > 0 && scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1) {
    const vec4 intex = load_texel(tin, scan_pos);
    for (int i = 0; i < nspill; ++i) {
      denominators.x += exp(intex[i] - max_element);
    }
  }
  shared_vecs[smi] = denominators;
  barrier();
  // Iterate over the partial sums to obtain the overall sum
  group_i = tid.y * NWORKERS;
  denominators = shared_vecs[group_i++];
  for (int i = 1; i < NWORKERS; ++i, group_i++) {
    denominators += shared_vecs[group_i];
  }
  // Reduce over the accumulated texel to find the overall sum
  float denominator = 0;
  [[unroll]] for (int i = 0; i < 4; ++i) {
    denominator += denominators[i];
  }

  scan_pos[reduce_dim] = tid.x;
  for (int i = tid.x * 4; i < reduce_len;
       i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
    const vec4 numerators = op1(load_texel(tin, scan_pos) - max_element);
    write_texel(tout, scan_pos, op2(numerators, denominator));
  }
  // For the last texel in the dim, if there are padding elements then the
  // padding elements need to be set to 0 explicitly, otherwise they may
  // influence subsequent operations.
  if (nspill > 0 && scan_pos[reduce_dim] == tout_limits[reduce_dim] - 1) {
    const vec4 numerator = op1(load_texel(tin, scan_pos) - max_element);
    vec4 outtex = op2(numerator, denominator);
    [[unroll]] for (int i = nspill; i < 4; ++i) {
      outtex[i] = 0;
    }
    write_texel(tout, scan_pos, outtex);
  }
}
249
// Entry point. Derives the (worker, group) thread id from the local invocation
// id, computes the starting texel position of the row this thread helps
// reduce, and dispatches to the packed or non-packed softmax variant.
void main() {
  // tid.x: index of this worker along the reduce dim within its group.
  // tid.y: index of the group this worker belongs to.
  const ivec2 tid = ivec2(
      gl_LocalInvocationID[reduce_dim], gl_LocalInvocationID[group_dim]);

  // Start every scan at the beginning of the reduce dim; the softmax routines
  // advance this coordinate themselves.
  ivec3 scan_pos = ivec3(gl_GlobalInvocationID);
  scan_pos[reduce_dim] = 0;

  // Skip rows that fall outside the output extents. The reduce dim coordinate
  // was zeroed above, so for a non-empty output only the other two
  // coordinates can trigger this exit.
  // NOTE(review): the routines below call barrier(); confirm that the dispatch
  // never lets only part of a work group take this early return.
  if (any(greaterThanEqual(scan_pos, tout_limits))) {
    return;
  }

  if (reduce_dim == packed_dim) {
    softmax_packed_dim(tid, scan_pos);
  } else {
    softmax_nonpacked_dim(tid, scan_pos);
  }
}
268