• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1/*
2 * Copyright (c) Meta Platforms, Inc. and affiliates.
3 * All rights reserved.
4 *
5 * This source code is licensed under the BSD-style license found in the
6 * LICENSE file in the root directory of this source tree.
7 */
8
9#version 450 core
10
11#define PRECISION ${PRECISION}
12
13#define VEC4_T ${texel_type(DTYPE)}
14
15#define TILE_SIZE ${TILE_SIZE}
16
17#define op(X, A, B) ${OPERATOR}
18
19#include "indexing_utils.h"
20
21layout(std430) buffer;
22
23${layout_declare_tensor(0, "w", "t_out", DTYPE, "texture3d")}
24${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")}
25${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")}
26${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")}
27${layout_declare_ubo(4, "ivec3", "out_limits")}
28${layout_declare_ubo(5, "ivec4", "in_sizes")}
29${layout_declare_ubo(6, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")}
30${layout_declare_ubo(7, "ivec2", "overlay_region", "int", "in_group_size")}
31${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}
32
33layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
34
35#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require
36
/*
 * Computes a 2D pointwise (1x1 kernel) convolution for a TILE_SIZE x TILE_SIZE
 * tile of output texels. Tiling is profitable for pointwise convolution because
 * the kernel is only 1x1: the four kernel texels loaded for one group of 4
 * input channels are re-used for every output position in the tile.
 *
 * Invocation (gpos.x, gpos.y, gpos.z) produces the output texels at
 *   x in [gpos.x * TILE_SIZE, gpos.x * TILE_SIZE + TILE_SIZE)
 *   y in [gpos.y * TILE_SIZE, gpos.y * TILE_SIZE + TILE_SIZE)
 *   z == gpos.z
 * Positions outside out_limits are skipped.
 */
void main() {
  const u16vec3 gpos = u16vec3(gl_GlobalInvocationID);

  // Output positions, in row-major order. Example for TILE_SIZE = 2:
  // +--------+--------+
  // | pos[0] | pos[1] |
  // +--------+--------+
  // | pos[2] | pos[3] |
  // +--------+--------+
  u16vec3 pos[TILE_SIZE * TILE_SIZE];
  for (int y = 0, i = 0; y < TILE_SIZE; ++y) {
    for (int x = 0; x < TILE_SIZE; ++x) {
      pos[i] = u16vec3(
          gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y, gpos.z);
      i++;
    }
  }

  // If the top left position is out of bounds, then this invocation will have
  // no work to do.
  if (any(greaterThanEqual(pos[0], out_limits))) {
    return;
  }

  // Compute the index of the input texel that needs to be loaded for each
  // output position. Negative indices are produced when an output position
  // maps into the region added by padding, so a *signed* 16-bit type is used;
  // an unsigned type would silently wrap such indices to large positive values.
  // NOTE(review): assumes input extents (and pos * stride) fit in int16.
  i16vec2 ipos[TILE_SIZE * TILE_SIZE];
  for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
    ipos[i] = i16vec2(pos[i].xy) * i16vec2(stride) - i16vec2(padding);
  }

  // Every accumulator starts from the bias texel for output texel gpos.z.
  vec4 sum[TILE_SIZE * TILE_SIZE];
  sum[0] = texelFetch(t_bias, u16vec2(gpos.z, 0), 0);
  for (int i = 1; i < TILE_SIZE * TILE_SIZE; ++i) {
    sum[i] = sum[0];
  }

  // z4 is the texel index of the current group of 4 channels (z4 == z / 4).
  int z4 = 0;
  // Since the kernel is 1x1, only the channel dimension has to be reduced.
  // Each iteration consumes one input texel (4 channels).
  for (uint16_t z = uint16_t(0); z < uint16_t(in_group_size); z += uint16_t(4), ++z4) {
    // During prepacking, the weight tensor has been permuted so that the
    // channel (IC) dim is along the x-axis, and the batch (OC) dim is along
    // the z-axis. These four texels are shared by every position in the tile.
    const vec4 ktex_0 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(0, 0));
    const vec4 ktex_1 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(1, 0));
    const vec4 ktex_2 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(2, 0));
    const vec4 ktex_3 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(3, 0));

#pragma unroll
    for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
      // Positions in the padding region have negative coordinates and so lie
      // outside t_in. NOTE(review): this relies on out-of-bounds texelFetch
      // returning zero (e.g. robust image access being enabled) — confirm.
      const vec4 in_tex = texelFetch(t_in, ivec3(ipos[i], z4), 0);
      // To explain the calculations below, the contents of one in_tex and the
      // group of 4 texels loaded from t_kernel are shown:
      //
      //   in_tex                 t_kernel
      //    -x->                   ---x--->
      //   +---+              +----+----+----+----+
      // ^ | w |           ^  | D0 | D1 | D2 | D3 |
      // | +---+           |  +----+----+----+----+
      // | | z |           |  | C0 | C1 | C2 | C3 |
      // z +---+           z  +----+----+----+----+
      // | | y |           |  | B0 | B1 | B2 | B3 |
      // | +---+           |  +----+----+----+----+
      //   | x |              | A0 | A1 | A2 | A3 |
      //   +---+              +----+----+----+----+
      //
      // In the t_kernel graphic, cells sharing the same letter are from
      // the same batch/output channel index, and the number denotes a unique
      // channel index. To calculate the output texel, the following
      // calculation is performed:
      //
      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
      //  | x | | D0 |   | y | | D1 |   | z | | D2 |   | w | | D3 |
      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
      //  | x | | C0 |   | y | | C1 |   | z | | C2 |   | w | | C3 |
      //  +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+
      //  | x | | B0 |   | y | | B1 |   | z | | B2 |   | w | | B3 |
      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
      //  | x | | A0 |   | y | | A1 |   | z | | A2 |   | w | | A3 |
      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
      //
      // which is what is expressed in the following calculations. This is done
      // for each output position.
      sum[i] = fma(in_tex.xxxx, ktex_0, sum[i]);
      sum[i] = fma(in_tex.yyyy, ktex_1, sum[i]);
      sum[i] = fma(in_tex.zzzz, ktex_2, sum[i]);
      sum[i] = fma(in_tex.wwww, ktex_3, sum[i]);
    }
  }

  // Apply the output operator (with out_min / out_max) and write only the
  // tile positions that are inside the output bounds.
  for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
    if (all(lessThan(pos[i], out_limits))) {
      imageStore(t_out, pos[i], op(sum[i], out_min, out_max));
    }
  }
}
141