• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#version 450 core
10
// PRECISION and VEC4_T are filled in from template placeholders; the
// substituted text is not visible in this file. VEC4_T is the type used for
// the output texel assembled in main().
11#define PRECISION ${PRECISION}
12
13#define VEC4_T ${texel_type(DTYPE)}
14
// Make std430 the default layout for buffer blocks declared after this point.
15layout(std430) buffer;
16
// Shader resources, declared through template helpers:
//   t_out    - declared with "w" and element type DTYPE
//   t_in     - declared with "r" and element type "int" (holds the embedding
//              indices read in main())
//   t_weight - declared with "r" and element type DTYPE
//   sizes    - ivec4 uniform block; main() bounds-checks the output tensor
//              index against it
// NOTE(review): "w"/"r" presumably mean write-only/read-only access, and the
// shared first argument B presumably assigns consecutive bindings in
// declaration order - confirm against the helper definitions before
// reordering these lines.
17${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
18${layout_declare_tensor(B, "r", "t_in", "int", STORAGE)}
19${layout_declare_tensor(B, "r", "t_weight", DTYPE, STORAGE)}
20${layout_declare_ubo(B, "ivec4", "sizes")}
21
// NOTE(review): presumably the source of lpos_to_tidx, load_texel_lpos,
// write_texel_lpos and the unhash_* helpers used below - confirm.
22#include "indexing_utils.h"
23
// Workgroup dimensions x/y/z are taken from specialization constants 0, 1, 2.
24layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
25
// One int layout specialization constant per tensor (default DEFAULT_LAYOUT).
// Each is decoded once at global scope into an axis map; the output layout
// additionally yields packed_dim.
26${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
27const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
28const lowp int packed_dim = unhash_packed_dim(out_layout);
29
30${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
31const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);
32
33${layout_declare_spec_const(C, "int", "weight_layout", "DEFAULT_LAYOUT")}
34const lowp ivec4 weight_axis_map = unhash_axis_map(weight_layout);
35
/*
 * Embedding lookup entry point.
 *
 * Each invocation assembles one texel of t_out. For every one of the texel's
 * four components it reads an embedding index from t_in, uses that index to
 * address t_weight, and keeps the .x component of the weight texel it finds.
 * Invocations whose tensor index falls outside `sizes` exit without writing.
 */
void main() {
  const ivec3 lpos = ivec3(gl_GlobalInvocationID);
  const ivec4 tidx = lpos_to_tidx(lpos, sizes, out_axis_map.w, packed_dim);

  // Nothing to do for invocations past the end of the output tensor.
  if (any(greaterThanEqual(tidx, sizes))) {
    return;
  }

  // Both the z coordinate of the index texel and the component picked out of
  // it depend only on tidx.w, so they are the same for all four lanes.
  const int idx_lpos_z = tidx.w / 4;
  const int idx_component = tidx.w % 4;

  VEC4_T gathered;

  // Consider optimizing via W-packing format for t_in and t_weight.
  for (int lane = 0; lane < 4; lane++) {
    // Fetch the embedding index for this lane from the input tensor.
    const ivec3 idx_lpos = ivec3(tidx.y, tidx.z * 4 + lane, idx_lpos_z);
    const int embedding_row =
        load_texel_lpos(t_in, idx_lpos, in_axis_map)[idx_component];

    // Fetch the embedding value for that index from the weight tensor.
    const ivec3 w_lpos = ivec3(tidx.x, embedding_row, 0);
    gathered[lane] = load_texel_lpos(t_weight, w_lpos, weight_axis_map).x;
  }

  write_texel_lpos(t_out, lpos, gathered, out_axis_map);
}
57