• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#version 450 core
10
// NOTE: every ${...} expression in this file is a template placeholder; it is
// presumably expanded by a shader code generator before the GLSL is compiled
// (confirm against the build tooling).

// Precision token supplied by the template. It is not referenced directly in
// the visible lines of this file; presumably it is consumed by the generated
// declarations or the included headers -- TODO confirm.
11#define PRECISION ${PRECISION}
12
// Four-component texel type chosen by the template for DTYPE. main() converts
// both loaded texels and the final result to this type.
13#define VEC4_T ${texel_type(DTYPE)}
14
// The binary operation itself, supplied by the template. At its single call
// site in main(), X is the texel from t_in, Y is the texel from t_other and A
// is the `alpha` uniform.
15#define op(X, Y, A) ${OPERATOR}
16
// Make std430 the default layout for buffer blocks declared below.
17layout(std430) buffer;
18
// Tensor bindings: t_out is declared with access "w", t_in and t_other with
// "r". main() writes only t_out and reads only t_in / t_other.
19${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
20${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
21${layout_declare_tensor(B, "r", "t_other", DTYPE, STORAGE)}
// Per-tensor sizes as ivec4. main() uses out_sizes for its bounds check and
// in_sizes / other_sizes when broadcasting the output index.
22${layout_declare_ubo(B, "ivec4", "out_sizes")}
23${layout_declare_ubo(B, "ivec4", "in_sizes")}
24${layout_declare_ubo(B, "ivec4", "other_sizes")}
// Two integer flags tested as `> 0` in main(): .x applies to the t_in texel,
// .y to the t_other texel. When set, the texel's first component is replicated
// across all four lanes.
25${layout_declare_ubo(B, "ivec2", "broadcast_params")}
// Scalar forwarded as the third argument of op().
26${layout_declare_ubo(B, "float", "alpha")}
27
28#include "broadcasting_utils.h"
29#include "indexing_utils.h"
30
// Workgroup dimensions are taken from specialization constants 0, 1 and 2.
31layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
32
// Each tensor's layout arrives as a single int specialization constant
// (defaulting to DEFAULT_LAYOUT) and is decoded by unhash_axis_map() /
// unhash_packed_dim(), which are not defined in this file -- presumably they
// come from the included headers.
33${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
34const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
// NOTE(review): packed_dim is decoded from out_layout only, yet main() also
// passes it when computing read positions for t_in and t_other. This looks
// like it assumes all three tensors share the same packed dim -- confirm.
35const lowp int packed_dim = unhash_packed_dim(out_layout);
36
37${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
38const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);
39
40${layout_declare_spec_const(C, "int", "other_layout", "DEFAULT_LAYOUT")}
41const lowp ivec4 other_axis_map = unhash_axis_map(other_layout);
42
/*
 * Shader entry point. Each invocation handles one output texel: it reads the
 * matching (broadcast) texel from t_in and from t_other, combines them with
 * op(), and writes the result to t_out at its own logical position.
 */
void main() {
  const ivec3 out_lpos = ivec3(gl_GlobalInvocationID);
  const ivec4 out_tidx =
      lpos_to_tidx(out_lpos, out_sizes, out_axis_map.w, packed_dim);

  // Invocations whose tensor index is not strictly inside out_sizes on every
  // component have nothing to write.
  if (!all(lessThan(out_tidx, out_sizes))) {
    return;
  }

  // First operand: broadcast the output index on t_in's logical sizes, then
  // read the texel at the axis-mapped position.
  const ivec4 lhs_tidx = broadcast_indices(out_tidx, in_sizes);
  VEC4_T lhs = VEC4_T(
      load_texel(t_in, tidx_to_pos(lhs_tidx, in_sizes, in_axis_map, packed_dim)));
  // The broadcast flags are booleans stored in an ivec2 (rather than a bvec2)
  // for alignment. A set flag replicates the first component to all lanes.
  if (broadcast_params.x > 0) {
    lhs = lhs.xxxx;
  }

  // Second operand: same procedure against t_other.
  const ivec4 rhs_tidx = broadcast_indices(out_tidx, other_sizes);
  VEC4_T rhs = VEC4_T(load_texel(
      t_other,
      tidx_to_pos(rhs_tidx, other_sizes, other_axis_map, packed_dim)));
  if (broadcast_params.y > 0) {
    rhs = rhs.xxxx;
  }

  const VEC4_T result = VEC4_T(op(lhs, rhs, alpha));
  write_texel_lpos(t_out, out_lpos, result, out_axis_map);
}
79