• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
#version 450 core

// Precision qualifier; substituted by the shader code generator.
#define PRECISION ${PRECISION}

// Four-component texel type for DTYPE, chosen by the code generator
// (presumably vec4 / ivec4 / uvec4 depending on dtype -- confirm against
// texel_type in the codegen).
#define VEC4_T ${texel_type(DTYPE)}

// Image coordinate expression generated for an NDIM-dimensional image from the
// name "pos"; presumably a variable named `pos` must be in scope wherever POS
// is used (main() declares one).
#define POS ${get_pos[NDIM]("pos")}
16
17#include "indexing_utils.h"
18
// Default memory layout for any buffer blocks declared in this shader.
layout(std430) buffer;

// Output image; main() writes one texel per in-bounds invocation.
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;

// Tensor sizes, one component per dimension (presumably WHCN order -- confirm
// against to_tensor_idx in indexing_utils.h).
layout(set = 0, binding = 1) uniform PRECISION restrict Sizes {
  ivec4 sizes;
};

// Scalar value written to every in-bounds element of image_out.
layout(set = 0, binding = 2) uniform PRECISION restrict FillVal {
  float fill_value;
};

// Workgroup dimensions are supplied by specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Tensor dimension whose consecutive elements occupy the four texel
// components; specialization constant 3, defaulting to C_DIM (defined outside
// this file, presumably in indexing_utils.h).
layout(constant_id = 3) const int packed_dim = C_DIM;
34
/*
 * Fills image_out with fill_value.
 *
 * Each invocation writes one texel, whose four components hold consecutive
 * elements along packed_dim starting at idx[packed_dim]. Components whose
 * element index is >= sizes[packed_dim] are padding and are written as zero.
 */
void main() {
  // `pos` must keep this name: the POS macro expands to an expression over it.
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim);

  // Invocations whose tensor index lies outside the tensor extents write nothing.
  if (any(greaterThanEqual(idx, sizes))) {
    return;
  }

  // Element indices along packed_dim covered by the four components of this texel.
  const ivec4 lane_idx = ivec4(idx[packed_dim]) + ivec4(0, 1, 2, 3);
  const bvec4 in_bounds = lessThan(lane_idx, ivec4(sizes[packed_dim]));

  // Select per component instead of multiplying by a 0/1 mask: with a
  // multiply, padding lanes become NaN when fill_value is +/-inf or NaN
  // (inf * 0 == NaN). mix(x, y, bvec) returns y where the selector is true and
  // x elsewhere, without any arithmetic on the values.
  const VEC4_T outtex = mix(VEC4_T(0), VEC4_T(fill_value), in_bounds);

  imageStore(image_out, POS, outtex);
}
55