/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

// Code-generation parameters, substituted by the shader template generator.
// NOTE(review): PRECISION and VEC4_T are not referenced in main(); they may be
// consumed by indexing_utils.h or by the generated declarations below, so
// confirm before removing them.
#define PRECISION ${PRECISION}

#define VEC4_T ${texel_type(DTYPE)}

// Edge length of the square output tile computed by each invocation.
#define TILE_SIZE ${TILE_SIZE}

// Output epilogue. main() applies it to every accumulated texel as
// op(sum, out_min, out_max) before storing. What it computes depends on the
// generated expansion (presumably a clamp/activation), so confirm in the
// generator config.
#define op(X, A, B) ${OPERATOR}

#include "indexing_utils.h"

layout(std430) buffer;

// Resource bindings 0-3. The second argument is presumably the access mode
// ("w" for the output, "r" for the inputs). t_out and t_in are 3D textures;
// the packed weights (t_kernel) and the bias (t_bias) are 2D textures.
${layout_declare_tensor(0, "w", "t_out", DTYPE, "texture3d")}
${layout_declare_tensor(1, "r", "t_in", DTYPE, "texture3d")}
${layout_declare_tensor(2, "r", "t_kernel", DTYPE, "texture2d")}
${layout_declare_tensor(3, "r", "t_bias", DTYPE, "texture2d")}
// Uniform parameters, bindings 4-8. main() reads out_limits, stride, padding,
// in_group_size, out_min and out_max.
// NOTE(review): in_sizes, kernel_size, dilation and overlay_region are
// declared but not used by main(). They are presumably kept so the binding
// layout matches what the host code supplies, so confirm before removing them.
${layout_declare_ubo(4, "ivec3", "out_limits")}
${layout_declare_ubo(5, "ivec4", "in_sizes")}
${layout_declare_ubo(6, "ivec2", "kernel_size", "ivec2", "stride", "ivec2", "padding", "ivec2", "dilation")}
${layout_declare_ubo(7, "ivec2", "overlay_region", "int", "in_group_size")}
${layout_declare_ubo(8, "float", "out_min", "float", "out_max")}

// Workgroup size is supplied through specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// Provides the 16-bit integer types (uint16_t, u16vec2, u16vec3) that main()
// uses for texel coordinates.
#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require

/*
 * Computes a 2D pointwise (1x1 kernel) convolution for an NxN tile of output
 * texels, where N = TILE_SIZE. Computing a whole tile per invocation is
 * efficient for pointwise convolution because, with a 1x1 kernel, the
 * t_kernel texels loaded for an output layer can be reused for every
 * position in the tile.
 *
 * Each invocation handles the TILE_SIZE x TILE_SIZE output texels of output
 * layer gpos.z whose top-left corner is (gpos.x * TILE_SIZE,
 * gpos.y * TILE_SIZE). Tile positions that fall outside out_limits are still
 * accumulated (their t_in reads may therefore be out of bounds), but they are
 * never stored.
 */
void main() {
  const u16vec3 gpos = u16vec3(gl_GlobalInvocationID);

  // Output position for TILE_SIZE = 2
  // +--------+--------+
  // | pos[0] | pos[1] |
  // +--------+--------+
  // | pos[2] | pos[3] |
  // +--------+--------+
  // Positions are filled in row-major order (x fastest) and all share gpos.z.
  u16vec3 pos[TILE_SIZE * TILE_SIZE];
  for (int y = 0, i = 0; y < TILE_SIZE; ++y) {
    for (int x = 0; x < TILE_SIZE; ++x) {
      pos[i] = u16vec3(
          gpos.x * TILE_SIZE + x, gpos.y * TILE_SIZE + y, gpos.z);
      i++;
    }
  }

  // If the top-left position is out of bounds, then this invocation has no
  // work to do.
  if (any(greaterThanEqual(pos[0], out_limits))) {
    return;
  }

  // Compute the index of the input texel that needs to be loaded for each
  // output position. Positions that fall in the region added by padding
  // produce indices outside the input.
  // NOTE(review): ipos is u16vec2, so results that would be negative wrap
  // modulo 2^16 to large positive values instead of becoming negative. Either
  // way the fetch is out of bounds, and the value read depends on how the
  // implementation handles out-of-bounds texelFetch (e.g. robust image access
  // returning zero, which matches zero padding). Confirm robustness is enabled
  // wherever padding > 0.
  u16vec2 ipos[TILE_SIZE * TILE_SIZE];
  for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
    ipos[i] = pos[i].xy * u16vec2(stride) - u16vec2(padding);
  }

  // Every tile position starts from the bias texel of this output layer
  // (t_bias at x = gpos.z).
  vec4 sum[TILE_SIZE * TILE_SIZE];
  sum[0] = texelFetch(t_bias, u16vec2(gpos.z, 0), 0);
  for (int i = 1; i < TILE_SIZE * TILE_SIZE; ++i) {
    sum[i] = sum[0];
  }

  // z walks the input-channel index in steps of 4 (the x coordinate into
  // t_kernel). z4 advances in lockstep and is the matching t_in texel layer,
  // i.e. z4 == z / 4.
  int z4 = 0;
  // Since the kernel is 1x1, we only have to loop over the depth dimension.
  // (in_group_size presumably holds the number of input channels in the group;
  // confirm against the host code.)
  for (uint16_t z = uint16_t(0); z < uint16_t(in_group_size); z += uint16_t(4), ++z4) {
    // During prepacking, the weight tensor has been permuted so that the
    // channel (IC) dim is along the x-axis, and the batch (OC) dim is along
    // the z-axis.
    // NOTE(review): t_kernel is a 2D texture sampled at y = gpos.z, so in
    // texel coordinates the OC texel index runs along y. The "z-axis" above
    // presumably refers to the permuted tensor before packing; confirm.
    // The four texels are loaded once per z and reused for every tile position.
    const vec4 ktex_0 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(0, 0));
    const vec4 ktex_1 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(1, 0));
    const vec4 ktex_2 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(2, 0));
    const vec4 ktex_3 = texelFetchOffset(t_kernel, u16vec2(z, gpos.z), 0, u16vec2(3, 0));

    // NOTE(review): `#pragma unroll` is not a pragma defined by the GLSL
    // specification. Compilers ignore pragmas they do not recognize, so any
    // unrolling here is up to the compiler. Confirm with the target toolchain
    // (GL_EXT_control_flow_attributes provides a portable [[unroll]]).
#pragma unroll
    for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
      const vec4 in_tex = texelFetch(t_in, u16vec3(ipos[i], z4), 0);
      // For a 2x2 tile size, the algorithm works as follows.
      // To explain the calculations below, the contents of one in_tex and the
      // group of 4 texels loaded from t_kernel are shown:
      //
      //  in_tex                 t_kernel
      //    -x->                   ---x--->
      //    +---+              +----+----+----+----+
      //  ^ | w |           ^  | D0 | D1 | D2 | D3 |
      //  | +---+           |  +----+----+----+----+
      //  | | z |           |  | C0 | C1 | C2 | C3 |
      //  z +---+           z  +----+----+----+----+
      //  | | y |           |  | B0 | B1 | B2 | B3 |
      //  | +---+           |  +----+----+----+----+
      //    | x |              | A0 | A1 | A2 | A3 |
      //    +---+              +----+----+----+----+
      //
      // In the t_kernel graphic, cells sharing the same letter are from
      // the same batch/output channel index, and the number denotes a unique
      // channel index. Column j of the graphic is the texel ktex_j; its x, y,
      // z and w components are rows A, B, C and D respectively. To calculate
      // the output texel, the following calculation is performed:
      //
      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
      //  | x | | D0 |   | y | | D1 |   | z | | D2 |   | w | | D3 |
      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
      //  | x | | C0 |   | y | | C1 |   | z | | C2 |   | w | | C3 |
      //  +---+X+----+ + +---+X+----+ + +---+X+----+ + +---+X+----+
      //  | x | | B0 |   | y | | B1 |   | z | | B2 |   | w | | B3 |
      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
      //  | x | | A0 |   | y | | A1 |   | z | | A2 |   | w | | A3 |
      //  +---+ +----+   +---+ +----+   +---+ +----+   +---+ +----+
      //
      // which is what is expressed in the following calculations. This is done
      // for each output position.
      sum[i] = fma(in_tex.xxxx, ktex_0, sum[i]);
      sum[i] = fma(in_tex.yyyy, ktex_1, sum[i]);
      sum[i] = fma(in_tex.zzzz, ktex_2, sum[i]);
      sum[i] = fma(in_tex.wwww, ktex_3, sum[i]);
    }
  }

  // Store only the tile positions that lie inside out_limits, after applying
  // the generated op() epilogue with out_min/out_max.
  for (int i = 0; i < TILE_SIZE * TILE_SIZE; ++i) {
    if (all(lessThan(pos[i], out_limits))) {
      imageStore(t_out, pos[i], op(sum[i], out_min, out_max));
    }
  }
}