/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

#define POS ${get_pos[NDIM]("pos")}

#include "indexing_utils.h"

layout(std430) buffer;

// Destination texture; written only, never read, by this shader.
layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly ${IMAGE_T[NDIM][DTYPE]} image_out;

// Logical tensor sizes. sizes[packed_dim] is the unpadded length of the packed
// dimension and is used below to detect lanes that lie past the tensor's end.
layout(set = 0, binding = 1) uniform PRECISION restrict Sizes {
  ivec4 sizes;
};

// Scalar broadcast into every element of the output.
layout(set = 0, binding = 2) uniform PRECISION restrict FillVal {
  float fill_value;
};

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Which tensor dimension is packed four-wide into each texel (specialization
// constant; defaults to C_DIM from indexing_utils.h).
layout(constant_id = 3) const int packed_dim = C_DIM;

/*
 * Fills image_out with fill_value, one texel per invocation.
 *
 * to_tensor_idx (indexing_utils.h) maps the texel position to a tensor index;
 * presumably idx[packed_dim] is the index of the texel's first element along
 * the packed dimension -- verify against indexing_utils.h. Invocations whose
 * index is out of range of `sizes` exit without writing.
 *
 * If the packed dimension's length is not a multiple of 4, the last texel along
 * that dimension has lanes beyond the end of the tensor; those lanes are
 * written as exactly 0 instead of fill_value.
 */
void main() {
  const ivec3 pos = ivec3(gl_GlobalInvocationID);
  const ivec4 idx = to_tensor_idx(pos, sizes, packed_dim);

  if (any(greaterThanEqual(idx, sizes))) {
    return;
  }

  VEC4_T outtex = VEC4_T(fill_value);
  const int packed_dim_size = sizes[packed_dim];
  const int packed_idx = idx[packed_dim];

  // Only the final texel along the packed dimension can contain padding lanes.
  if (packed_idx + 3 >= packed_dim_size) {
    const ivec4 packed_ind = ivec4(packed_idx) + ivec4(0, 1, 2, 3);
    const bvec4 valid_lane = lessThan(packed_ind, ivec4(packed_dim_size));
    // Select instead of multiplying by a 0/1 mask: IEEE 754 defines
    // inf * 0 and NaN * 0 as NaN, so a masked multiply would leave NaN rather
    // than 0 in the padding lanes when fill_value is +/-inf or NaN.
    // mix() with a bvec selector is core in GLSL 4.50 for float/int/uint.
    outtex = mix(VEC4_T(0), outtex, valid_lane);
  }

  imageStore(image_out, POS, outtex);
}