/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

// Element-wise binary operator supplied by the code generator. In main() it is
// invoked as op(<in texel>, <other texel>, alpha).
#define op(X, Y, A) ${OPERATOR}

layout(std430) buffer;

// NOTE(review): binding slots are presumably handed out in declaration order
// by the template expander -- keep this ordering stable.
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec4", "out_sizes")}
${layout_declare_ubo(B, "ivec4", "in_sizes")}
${layout_declare_ubo(B, "ivec4", "other_sizes")}
${layout_declare_ubo(B, "ivec2", "broadcast_params")}
${layout_declare_ubo(B, "float", "alpha")}

#include "broadcasting_utils.h"
#include "indexing_utils.h"

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
const lowp int packed_dim = unhash_packed_dim(out_layout);

${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);

${layout_declare_spec_const(C, "int", "other_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 other_axis_map = unhash_axis_map(other_layout);

// Each invocation produces one output texel:
//   out = op(in, other, alpha)
// Both inputs are broadcast onto the output's logical sizes.
void main() {
  const ivec3 out_lpos = ivec3(gl_GlobalInvocationID);
  const ivec4 out_tidx =
      lpos_to_tidx(out_lpos, out_sizes, out_axis_map.w, packed_dim);

  // Invocations whose tensor index lies outside the output do no work.
  if (any(greaterThanEqual(out_tidx, out_sizes))) {
    return;
  }

  // Broadcast the output index onto each operand's logical sizes, then fetch
  // the texel at that operand's axis-mapped position.
  const ivec4 lhs_tidx = broadcast_indices(out_tidx, in_sizes);
  VEC4_T lhs_texel = VEC4_T(load_texel(
      t_in, tidx_to_pos(lhs_tidx, in_sizes, in_axis_map, packed_dim)));

  const ivec4 rhs_tidx = broadcast_indices(out_tidx, other_sizes);
  VEC4_T rhs_texel = VEC4_T(load_texel(
      t_other,
      tidx_to_pos(rhs_tidx, other_sizes, other_axis_map, packed_dim)));

  // Broadcast flags are passed as an ivec2 (not bvec2) for alignment. A
  // positive flag replicates that operand's x component into all four lanes.
  lhs_texel = (broadcast_params.x > 0) ? lhs_texel.xxxx : lhs_texel;
  rhs_texel = (broadcast_params.y > 0) ? rhs_texel.xxxx : rhs_texel;

  // Operands are plain identifiers so the op() macro expansion cannot be
  // affected by operator precedence.
  const VEC4_T out_texel = VEC4_T(op(lhs_texel, rhs_texel, alpha));
  write_texel_lpos(t_out, out_lpos, out_texel, out_axis_map);
}