/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

layout(std430) buffer;

${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", "int", STORAGE)}
${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE)}
${layout_declare_ubo(B, "ivec4", "sizes")}

#include "indexing_utils.h"

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
const lowp int packed_dim = unhash_packed_dim(out_layout);

${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);

${layout_declare_spec_const(C, "int", "weight_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 weight_axis_map = unhash_axis_map(weight_layout);

/*
 * Embedding lookup. Each invocation writes one texel of t_out. Lane i of that
 * texel receives the .x component of the t_weight texel at column out_tidx.x
 * of the row whose number is read from t_in.
 *
 * Indexing assumptions (NOTE(review): confirm against the C++ dispatch code):
 *  - `sizes` holds the output sizes in WHCN order. lpos_to_tidx() returns
 *    element (not texel) coordinates, which is why its result is bounds-checked
 *    directly against `sizes` below.
 *  - The four lanes of an output texel cover out_tidx.z .. out_tidx.z + 3.
 *  - Output element (x, y, z, w) takes its row index from the t_in texel at
 *    logical position (y, z, w / 4), lane (w % 4).
 */
void main() {
  const ivec3 out_lpos = ivec3(gl_GlobalInvocationID);
  const ivec4 out_tidx =
      lpos_to_tidx(out_lpos, sizes, out_axis_map.w, packed_dim);
  if (any(greaterThanEqual(out_tidx, sizes))) {
    return;
  }

  // Zero-initialized so that padding lanes skipped below are written as 0.
  VEC4_T out_texel = VEC4_T(0);

  // Consider optimizing via W-packing format for t_in and t_weight.
  for (int i = 0; i < 4; ++i) {
    // out_tidx.z is already an element coordinate. Scaling it by 4 again would
    // read the wrong rows for every texel past the first along this dim.
    const int in_row = out_tidx.z + i;

    // Lanes at or past sizes.z are texel padding. Stop here instead of
    // fetching t_in out of bounds and using the garbage as a weight row.
    if (in_row >= sizes.z) {
      break;
    }

    // Read input tensor for embedding index.
    const ivec3 in_lpos = ivec3(out_tidx.y, in_row, out_tidx.w / 4);
    const int in_texel_elem =
        load_texel_lpos(t_in, in_lpos, in_axis_map)[out_tidx.w % 4];

    // Read weight tensor for embedding.
    const ivec3 weight_lpos = ivec3(out_tidx.x, in_texel_elem, 0);
    out_texel[i] = load_texel_lpos(t_weight, weight_lpos, weight_axis_map).x;
  }

  write_texel_lpos(t_out, out_lpos, out_texel, out_axis_map);
}