• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#version 450 core
10
// Placeholder replaced by the shader template generator before compilation.
11#define PRECISION ${PRECISION}
12
// VEC4_T is the texel type used by the texture variant of main(); T is the
// scalar type used by the buffer variant. Both are produced from DTYPE by
// template helpers.
13#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
14#define T ${buffer_scalar_type(DTYPE)}
15
// op(X, A, B) expands to the template-supplied OPERATOR expression. Both
// variants of main() invoke it as op(input, minimum, maximum).
16#define op(X, A, B) ${OPERATOR}
17
// NOTE(review): presumably this defines USING_BUFFER when STORAGE is
// "buffer", which is what selects the main() variant further down -- confirm
// against the code-generation helper.
18${define_active_storage_type(STORAGE)}
19
20#include "indexing_utils.h"
21
22${define_required_extensions(DTYPE)}
23
24layout(std430) buffer;
25
// Bindings 0 and 1: the output tensor t_out and the input tensor t_in.
// NOTE(review): "w" / "r" look like write / read access modes -- confirm.
26${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
27${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
// Binding 2: the bound main() checks the invocation id against -- an int
// named numel for buffer storage, an ivec3 named out_limits otherwise.
28$if STORAGE == "buffer":
29  ${layout_declare_ubo(2, "int", "numel")}
30$else:
31  ${layout_declare_ubo(2, "ivec3", "out_limits")}
// Bindings 3 and 4: float values passed to op as its A and B arguments.
32${layout_declare_ubo(3, "float", "minimum")}
33${layout_declare_ubo(4, "float", "maximum")}
34
// Workgroup dimensions are taken from specialization constants 0, 1 and 2.
35layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
36
37#include "activations.h"
38
39#ifdef USING_BUFFER
40
// Entry point compiled when USING_BUFFER is defined.
// Each invocation handles the single element selected by the x component of
// its global invocation id: the element is read from t_in, converted to
// float, passed through op together with minimum and maximum, and the result
// is converted to T and written to the same index of t_out.
void main() {
  const int buf_idx = int(gl_GlobalInvocationID.x);

  // Invocations whose index is at or beyond numel do nothing.
  if (buf_idx < numel) {
    const float src_val = float(t_in[buf_idx]);
    t_out[buf_idx] = T(op(src_val, minimum, maximum));
  }
}
50
51#else
52
// Entry point compiled when USING_BUFFER is not defined.
// Each invocation handles the texel at its 3D global invocation id: the texel
// is fetched from t_in at lod 0, passed through op together with minimum and
// maximum, converted to VEC4_T, and stored at the same position in t_out.
void main() {
  const ivec3 texel_pos = ivec3(gl_GlobalInvocationID);

  // Work only when every component lies strictly below out_limits; any
  // component at or beyond its limit makes the invocation a no-op.
  if (all(lessThan(texel_pos, out_limits))) {
    const VEC4_T src_texel = texelFetch(t_in, texel_pos, 0);
    imageStore(t_out, texel_pos, VEC4_T(op(src_texel, minimum, maximum)));
  }
}
63
64#endif
65