/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_load_type(DTYPE, STORAGE)}
#define T ${buffer_scalar_type(DTYPE)}

// Elementwise operator body, supplied through the OPERATOR template parameter.
// X is the input value (a float scalar on the buffer path, a VEC4_T texel on
// the texture path); A and B receive the minimum and maximum uniforms below.
// Operators that do not need A or B presumably just ignore them.
#define op(X, A, B) ${OPERATOR}

${define_active_storage_type(STORAGE)}

#include "indexing_utils.h"

${define_required_extensions(DTYPE)}

layout(std430) buffer;

// Binding 2 carries the bounds used to discard surplus invocations: the
// element count for buffer storage, the output texture extents otherwise.
${layout_declare_tensor(0, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(1, "r", "t_in", DTYPE, STORAGE)}
$if STORAGE == "buffer":
  ${layout_declare_ubo(2, "int", "numel")}
$else:
  ${layout_declare_ubo(2, "ivec3", "out_limits")}
${layout_declare_ubo(3, "float", "minimum")}
${layout_declare_ubo(4, "float", "maximum")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

#include "activations.h"

// NOTE(review): USING_BUFFER is presumably defined by
// define_active_storage_type when STORAGE is "buffer" -- confirm.
#ifdef USING_BUFFER

// Buffer storage: each invocation handles the single scalar element addressed
// by the x component of its global invocation ID.
void main() {
  const int flat_idx = int(gl_GlobalInvocationID.x);

  // Only invocations that land on an existing element do any work.
  if (flat_idx < numel) {
    // The operator is evaluated in float, then converted back to the
    // buffer's scalar type T on store.
    float x_in = float(t_in[flat_idx]);
    t_out[flat_idx] = T(op(x_in, minimum, maximum));
  }
}

#else

// Texture storage: each invocation handles one 4-component texel at its
// global invocation ID.
void main() {
  const ivec3 coord = ivec3(gl_GlobalInvocationID);

  // Discard invocations that fall outside the output extents in any dimension.
  if (!all(lessThan(coord, out_limits))) {
    return;
  }

  VEC4_T texel_in = texelFetch(t_in, coord, 0);
  VEC4_T texel_out = VEC4_T(op(texel_in, minimum, maximum));
  imageStore(t_out, coord, texel_out);
}

#endif