• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#version 450 core
10
// PRECISION, VEC4_T and op() expand to ${...} template placeholders.
// NOTE(review): presumably substituted by the shader code generator before
// this file is compiled as GLSL - confirm against the build scripts.
11#define PRECISION ${PRECISION}
12
// Texel vector type used for the accumulator and for texture fetches.
13#define VEC4_T ${texel_type(DTYPE)}
14
// Post-processing applied to the accumulated result right before it is
// stored; main() passes the result together with out_min and out_max.
15#define op(X, A, B) ${OPERATOR}
16
17#include "indexing_utils.h"
18
19layout(std430) buffer;
20
// Bindings 0-3: tensors. main() writes only t_out (imageStore) and reads
// t_in, t_kernel and t_bias (texelFetch).
21${layout_declare_tensor(0, "w", "t_out", DTYPE, "texture3d")}
22${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")}
23${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")}
24${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")}
// Invocations whose position is >= out_limits in any component exit early.
25${layout_declare_ubo(4, "ivec3", "out_limits")}
// Only in_sizes.xy is used by main(), as the exclusive upper bound of the
// input x/y ranges.
26${layout_declare_ubo(5, "ivec4", "in_sizes")}
// Convolution parameters. dilation is declared here but is not referenced by
// main() in this file.
27${layout_declare_ubo(6, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")}
// overlay_region offsets the kernel texture coordinates; in_group_size scales
// the kernel x coordinate and sets the channel loop count (in_group_size / 4).
28${layout_declare_ubo(7, "ivec2", "overlay_region", "int", "in_group_size")}
// Forwarded unchanged to op() as its second and third arguments.
29${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
30
// Workgroup dimensions are taken from specialization constants 0, 1 and 2.
31layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
32
// Specialization constant 3. Not referenced by main() in this file; C_DIM is
// not defined here - presumably it comes from indexing_utils.h (TODO confirm).
33layout(constant_id = 3) const int packed_dim = C_DIM;
34
/*
 * 2D transposed convolution. One invocation produces exactly one output
 * texel: it determines which input positions can contribute to that texel,
 * then accumulates input-times-kernel products on top of the bias value.
 * conv2d.glsl uses a similar approach and is the place to look for details.
 */
void main() {
  const ivec3 out_pos = ivec3(gl_GlobalInvocationID);

  // Invocations that fall outside the output extents have nothing to write.
  if (any(greaterThanEqual(out_pos, out_limits))) {
    return;
  }

  const ivec2 padded_pos = out_pos.xy + padding;
  const vec2 padded_pos_f = vec2(padded_pos);
  const vec2 stride_f = vec2(stride);

  // Half-open range [in_first, in_stop) of input x/y positions visited for
  // this output texel, clamped to the input extents.
  const ivec2 in_first = max(
      ivec2(0),
      ivec2(ceil((padded_pos_f - kernel_size + 1) / stride_f)));
  const ivec2 in_stop = min(
      ivec2(in_sizes.xy),
      ivec2(floor(padded_pos_f / stride_f)) + 1);

  const int group_channels = in_group_size;
  // Each input texel covers four channels, hence four kernel columns.
  const int channel_texels = group_channels / 4;
  // Extra kernel-column advance applied after the channel loop of every x
  // step. Together with the +4 per texel this moves kx by
  // in_group_size * stride.x per x step when in_group_size is a multiple of 4.
  const int kx_step = group_channels * (stride.x - 1);

  // Kernel coordinates that pair with the first visited input position.
  const ivec2 k_base = overlay_region - 1 - (padded_pos - stride * in_first);

  // Start from the bias; every product below is accumulated onto it.
  VEC4_T acc = texelFetch(t_bias, ivec2(out_pos.z, 0), 0);

  // Kernel rows for this output z slice start at out_pos.z * kernel_size.y.
  int ky = k_base.y + out_pos.z * kernel_size.y;
  for (int y = in_first.y; y < in_stop.y; ++y) {
    // The kernel column restarts for every input row.
    int kx = k_base.x * group_channels;
    for (int x = in_first.x; x < in_stop.x; ++x) {
      for (int z4 = 0; z4 < channel_texels; ++z4) {
        const VEC4_T in_texel = texelFetch(t_in, ivec3(x, y, z4), 0);

        // One kernel texel per input channel component, in x, y, z, w order.
        acc = fma(in_texel.xxxx, texelFetch(t_kernel, ivec2(kx, ky), 0), acc);
        acc = fma(in_texel.yyyy, texelFetch(t_kernel, ivec2(kx + 1, ky), 0), acc);
        acc = fma(in_texel.zzzz, texelFetch(t_kernel, ivec2(kx + 2, ky), 0), acc);
        acc = fma(in_texel.wwww, texelFetch(t_kernel, ivec2(kx + 3, ky), 0), acc);

        kx += 4;
      }
      kx += kx_step;
    }
    ky += stride.y;
  }

  imageStore(t_out, out_pos, op(acc, out_min, out_max));
}
79