/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define op1(X) ${OPERATOR1}

#define op2(X, Y) ${OPERATOR2}

${define_active_storage_type(STORAGE)}

#extension GL_EXT_control_flow_attributes : require

layout(std430) buffer;

${layout_declare_tensor(B, "w", "tout", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "tin", DTYPE, STORAGE)}

${layout_declare_ubo(B, "ivec3", "tout_limits")}
${layout_declare_ubo(B, "ivec4", "tin_sizes")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = 0;
layout(constant_id = 4) const int reduce_dim = 0;
layout(constant_id = 5) const int group_dim = 1;

// Number of threads ("workers") that co-operate to produce one reduction
// output; a more descriptive name would be NWORKERS_PER_GROUP. One work group
// may hold several such groups, each producing a different reduction output.
#define NWORKERS 4

// Upper bound on the number of threads in a work group. It is fixed by the
// length of the shared_vecs array below, in which every thread of the work
// group owns exactly one slot.
#define MAX_NTHREADS 16

shared vec4 shared_vecs[MAX_NTHREADS];

#include "indexing_utils.h"

// Maps a thread id (x = worker index inside its group, y = group index) to the
// shared_vecs slot that the thread owns. The slots of group tid.y start at
// tid.y * NWORKERS, so (for tid.x < NWORKERS) a group's slots are contiguous.
int tid_to_smi(const ivec2 tid) {
  const int group_base = NWORKERS * tid.y;
  return group_base + tid.x;
}

/*
 * The shaders below compute softmax for a tensor. Softmax is an interesting mix
 * between a reduction operator and a unary elementwise operator, defined as
 * exp(x) / (sum of exp(x)). The general flow of the computation is:
 *
 * First, find the maximum element along the reduction dim.
 * The maximum element is used to preserve numerical stability: softmax is
 * translation invariant (exp(x - c) / sum(exp(x - c)) gives the same result
 * for any constant c), and choosing c = max_element keeps every exponent <= 0
 * so that exp() cannot overflow.
 *
 * Next, compute the sum of exp(x - max_element) along the reduction dim.
 *
 * Finally, for each element along the reduction dim, we compute the output as
 * exp(x - max_element) / sum_of_exponents.
 *
 * The shaders below also utilize shared memory to have multiple threads help
 * compute the max and sum reduction operations. Each work group contains
 * NGROUPS x NWORKERS threads (a total that must not exceed MAX_NTHREADS). Each
 * group works on a unique reduction "row", and within a group NWORKERS threads
 * co-operate to compute the max and sum of one "row". Each worker in the group
 * is responsible for computing a partial output of the "row" and writing it to
 * shared memory; the overall reduction output can then be determined by
 * aggregating the partial outputs stored in shared memory.
 *
 * As a caveat, this shader does not currently support cases where `batch` > 1
 * and the reduce dim happens to also be the batch concatenation dim. To support
 * this, there will need to be additional logic to set the starting value of
 * `scan_pos[reduce_dim]`. Since this is not expected to be a common use-case,
 * supporting this case is left as an exercise for when it is required.
 *
 * As a final note, log softmax is also supported by this shader via the op1
 * and op2 macro definitions. See the corresponding YAML file for more details.
 */

/*
 * Computes softmax where the reduction dim is orthogonal to the packed dim.
 * This case is simpler because each element of a texel belongs to a separate
 * reduction "row", meaning we don't have to perform reduction along a texel.
*/
// NOTE: this function executes barrier(), so every invocation of the work
// group must call it. Invocations whose output position is out of range pass
// in_bounds = false: they still take part in every barrier() but never read
// from tin or write to tout.
void softmax_nonpacked_dim(
    const ivec2 tid,
    ivec3 scan_pos,
    const bool in_bounds) {
  // shared memory index of this thread
  const int smi = tid_to_smi(tid);
  // first shared memory index owned by this thread's group
  const int group_start = tid.y * NWORKERS;
  const int reduce_size = tin_sizes[reduce_dim];

  // This thread computes a partial maximum. The accumulator is seeded with the
  // first element of the row, which always exists; seeding with element tid.x
  // would read past the end of the row when it is shorter than NWORKERS.
  vec4 max_elements = vec4(0);
  if (in_bounds) {
    scan_pos[reduce_dim] = 0;
    max_elements = load_texel(tin, scan_pos);
    scan_pos[reduce_dim] = tid.x;
    for (int i = tid.x; i < reduce_size;
         i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
      max_elements = max(max_elements, load_texel(tin, scan_pos));
    }
  }
  shared_vecs[smi] = max_elements;
  barrier();
  // Combine the partial maximums of the group to obtain the overall maximum
  max_elements = shared_vecs[group_start];
  for (int w = 1; w < NWORKERS; ++w) {
    max_elements = max(max_elements, shared_vecs[group_start + w]);
  }
  // Every thread of the group must finish reading the partial maximums before
  // any thread overwrites its shared_vecs slot with a partial sum below.
  barrier();

  // This thread computes a partial sum of exponents
  vec4 denominators = vec4(0);
  if (in_bounds) {
    scan_pos[reduce_dim] = tid.x;
    for (int i = tid.x; i < reduce_size;
         i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
      denominators += exp(load_texel(tin, scan_pos) - max_elements);
    }
  }
  shared_vecs[smi] = denominators;
  barrier();
  // Combine the partial sums of the group to obtain the overall sum
  denominators = shared_vecs[group_start];
  for (int w = 1; w < NWORKERS; ++w) {
    denominators += shared_vecs[group_start + w];
  }

  // No barrier() follows, so out-of-range invocations may leave now.
  if (!in_bounds) {
    return;
  }

  // Number of valid elements in the final texel of the packed dim; 0 means the
  // final texel has no padding elements.
  const int nspill = mod4(tin_sizes[packed_dim]);
  // Whether this thread works on the final texels of the packed dim, which
  // may contain padding elements
  const bool is_last_texel =
      scan_pos[packed_dim] == (tout_limits[packed_dim] - 1);

  scan_pos[reduce_dim] = tid.x;
  for (int i = tid.x; i < reduce_size;
       i += NWORKERS, scan_pos[reduce_dim] += NWORKERS) {
    const vec4 numerators = op1(load_texel(tin, scan_pos) - max_elements);
    vec4 outtex = op2(numerators, denominators);
    // For the last texel in the packed dim, explicitly set the padding
    // elements to 0 so that they cannot influence later computations.
    if (is_last_texel && nspill > 0) {
      [[unroll]] for (int c = nspill; c < 4; ++c) {
        outtex[c] = 0.0;
      }
    }
    write_texel(tout, scan_pos, outtex);
  }
}

/*
 * Computes softmax where the reduction dim is also the packed dim. This case is
 * more complex because the reduction has to happen within individual texels as
 * well as across them. Each lane of the accumulator texels therefore holds a
 * partial output, and the lanes are reduced at the end. Padding elements (which
 * exist when the size of the packed dim is not a multiple of 4) must be
 * ignored so that they do not influence the result of the reduction.
 *
 * Like softmax_nonpacked_dim, this executes barrier(), so every invocation of
 * the work group must call it; in_bounds = false skips all reads from tin and
 * all writes to tout while still participating in every barrier().
 */
void softmax_packed_dim(
    const ivec2 tid,
    ivec3 scan_pos,
    const bool in_bounds) {
  // shared memory index of this thread
  const int smi = tid_to_smi(tid);
  // first shared memory index owned by this thread's group
  const int group_start = tid.y * NWORKERS;

  // Number of valid elements in the final texel; 0 means it is a full texel
  const int nspill = mod4(tin_sizes[packed_dim]);
  // Number of elements that are covered by full (padding-free) texels
  const int reduce_len = tin_sizes[packed_dim] - nspill;
  // Texel index of the final texel along the reduction dim
  const int last_texel = tout_limits[reduce_dim] - 1;

  // This thread computes a partial maximum. Every lane is seeded with element
  // 0 of the row, which always exists; seeding from texel tid.x would read
  // past the end of the row when it has fewer than NWORKERS texels.
  vec4 max_elements = vec4(0);
  if (in_bounds) {
    scan_pos[reduce_dim] = 0;
    max_elements = vec4(load_texel(tin, scan_pos).x);
    scan_pos[reduce_dim] = tid.x;
    for (int i = tid.x * 4; i < reduce_len;
         i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
      max_elements = max(max_elements, load_texel(tin, scan_pos));
    }
    // If the final texel contains padding, fold in only its valid elements
    if (nspill > 0 && scan_pos[reduce_dim] == last_texel) {
      const vec4 intex = load_texel(tin, scan_pos);
      for (int c = 0; c < nspill; ++c) {
        max_elements.x = max(intex[c], max_elements.x);
      }
    }
  }
  shared_vecs[smi] = max_elements;
  barrier();
  // Combine the partial maximums of the group
  max_elements = shared_vecs[group_start];
  for (int w = 1; w < NWORKERS; ++w) {
    max_elements = max(max_elements, shared_vecs[group_start + w]);
  }
  // Every thread of the group must finish reading the partial maximums before
  // any thread overwrites its shared_vecs slot with a partial sum below.
  barrier();
  // Each lane of the texel is itself a partial maximum; reduce across lanes
  float max_element = max_elements.x;
  [[unroll]] for (int c = 1; c < 4; ++c) {
    max_element = max(max_elements[c], max_element);
  }

  // This thread computes a partial sum of exponents
  vec4 denominators = vec4(0);
  if (in_bounds) {
    scan_pos[reduce_dim] = tid.x;
    for (int i = tid.x * 4; i < reduce_len;
         i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
      denominators += exp(load_texel(tin, scan_pos) - max_element);
    }
    // If the final texel contains padding, sum only its valid elements
    if (nspill > 0 && scan_pos[reduce_dim] == last_texel) {
      const vec4 intex = load_texel(tin, scan_pos);
      for (int c = 0; c < nspill; ++c) {
        denominators.x += exp(intex[c] - max_element);
      }
    }
  }
  shared_vecs[smi] = denominators;
  barrier();
  // Combine the partial sums of the group
  denominators = shared_vecs[group_start];
  for (int w = 1; w < NWORKERS; ++w) {
    denominators += shared_vecs[group_start + w];
  }

  // No barrier() follows, so out-of-range invocations may leave now.
  if (!in_bounds) {
    return;
  }

  // Reduce across the lanes of the accumulated texel to find the overall sum
  float denominator = 0.0;
  [[unroll]] for (int c = 0; c < 4; ++c) {
    denominator += denominators[c];
  }

  scan_pos[reduce_dim] = tid.x;
  for (int i = tid.x * 4; i < reduce_len;
       i += NWORKERS * 4, scan_pos[reduce_dim] += NWORKERS) {
    const vec4 numerators = op1(load_texel(tin, scan_pos) - max_element);
    write_texel(tout, scan_pos, op2(numerators, denominator));
  }
  // For the final texel, if it contains padding, the padding elements are set
  // to 0 explicitly so that they cannot influence subsequent operations.
  if (nspill > 0 && scan_pos[reduce_dim] == last_texel) {
    const vec4 numerator = op1(load_texel(tin, scan_pos) - max_element);
    vec4 outtex = op2(numerator, denominator);
    [[unroll]] for (int c = nspill; c < 4; ++c) {
      outtex[c] = 0.0;
    }
    write_texel(tout, scan_pos, outtex);
  }
}

void main() {
  ivec3 scan_pos = ivec3(gl_GlobalInvocationID);
  scan_pos[reduce_dim] = 0;

  const ivec2 tid = ivec2(
      gl_LocalInvocationID[reduce_dim],
      gl_LocalInvocationID[group_dim]);

  // Out-of-range invocations must not return early here: the softmax
  // functions execute barrier(), which has to be reached by every invocation
  // of the work group. Such invocations are flagged instead and skip all
  // tensor reads and writes. Because scan_pos[reduce_dim] was zeroed above,
  // the flag does not depend on a thread's position along the reduce dim.
  const bool in_bounds = all(lessThan(scan_pos, tout_limits));

  if (reduce_dim != packed_dim) {
    softmax_nonpacked_dim(tid, scan_pos, in_bounds);
  } else {
    softmax_packed_dim(tid, scan_pos, in_bounds);
  }
}