/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

// Post-processing applied to the accumulated result right before it is
// stored; the expression itself is substituted at code-generation time.
#define op(X, A, B) ${OPERATOR}

#include "indexing_utils.h"

layout(std430) buffer;

${layout_declare_tensor(0, "w", "t_out", DTYPE, "texture3d")}
${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")}
${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")}
${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")}
${layout_declare_ubo(4, "ivec3", "out_limits")}
${layout_declare_ubo(5, "ivec4", "in_sizes")}
${layout_declare_ubo(6, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")}
${layout_declare_ubo(7, "ivec2", "overlay_region", "int", "in_group_size")}
${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}

layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

layout(constant_id = 3) const int packed_dim = C_DIM;

/*
 * 2D transpose convolution. One invocation produces the output texel located
 * at gl_GlobalInvocationID; invocations outside out_limits do nothing. The
 * overall scheme mirrors the one used in conv2d.glsl, which is the better
 * place to look for a detailed walkthrough.
 *
 * The accumulator starts from the bias texel selected by the output z
 * coordinate. Every input texel inside the contributing window is then
 * multiplied, one component at a time, with four consecutive texels of
 * t_kernel and added to the accumulator. The result is passed through op()
 * together with out_min / out_max and written to t_out.
 */
void main() {
  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);

  // Skip invocations that fall outside of the output extents.
  if (any(greaterThanEqual(out_pos, out_limits))) {
    return;
  }

  // Output xy coordinate shifted by the padding.
  const ivec2 shifted = out_pos.xy + padding;
  const vec2 shifted_f = vec2(shifted);
  const vec2 stride_f = vec2(stride);

  // Half-open xy range [in_begin, in_end) of input positions that are
  // visited, clamped to [0, in_sizes.xy).
  const ivec2 in_begin = max(
      ivec2(0),
      ivec2(ceil((shifted_f - vec2(kernel_size) + 1.0) / stride_f)));
  const ivec2 in_end = min(
      in_sizes.xy,
      ivec2(floor(shifted_f / stride_f)) + 1);

  const int group_ic = in_group_size;
  // Number of 4-wide texels visited along z for every (x, y) input position.
  const int num_texels = group_ic / 4;
  // Additional kernel columns skipped after each x step, on top of the
  // columns already consumed by the z loop.
  const int kx_skip = group_ic * (stride.x - 1);

  // Distance between the shifted output position and the first visited input
  // position once that position is scaled by the stride.
  const ivec2 offset = shifted - stride * in_begin;

  // Kernel texture coordinates for the first visited input position. The row
  // is additionally displaced by out_pos.z * kernel_size.y.
  const int ky_first =
      overlay_region.y - 1 - offset.y + out_pos.z * kernel_size.y;
  const int kx_first = (overlay_region.x - 1 - offset.x) * group_ic;

  VEC4_T acc = texelFetch(t_bias, ivec2(out_pos.z, 0), 0);

  int ky = ky_first;
  for (int y = in_begin.y; y < in_end.y; ++y) {
    int kx = kx_first;
    for (int x = in_begin.x; x < in_end.x; ++x) {
      for (int z4 = 0; z4 < num_texels; ++z4) {
        const VEC4_T texel = texelFetch(t_in, ivec3(x, y, z4), 0);

        // Each component of the input texel is paired with its own kernel
        // texel; the four kernel texels sit in consecutive columns.
        acc = fma(texel.xxxx, texelFetch(t_kernel, ivec2(kx, ky), 0), acc);
        acc = fma(texel.yyyy, texelFetch(t_kernel, ivec2(kx + 1, ky), 0), acc);
        acc = fma(texel.zzzz, texelFetch(t_kernel, ivec2(kx + 2, ky), 0), acc);
        acc = fma(texel.wwww, texelFetch(t_kernel, ivec2(kx + 3, ky), 0), acc);

        kx += 4;
      }
      kx += kx_skip;
    }
    ky += stride.y;
  }

  imageStore(t_out, out_pos, op(acc, out_min, out_max));
}