/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */

#version 450 core

// NOTE(review): the `${...}` expressions in this file look like template
// placeholders that are substituted before the file is compiled as GLSL —
// confirm against the shader build tooling.

#define PRECISION ${PRECISION}

// Four-component texel type derived from DTYPE. main() uses it for the bias
// and weight texels and for the running sum.
#define VEC4_T ${texel_type(DTYPE)}

// Expression applied to each result just before it is written. main() invokes
// it as op(sum + bias.x, out_min, out_max); the expression itself is supplied
// by OPERATOR.
#define op(X, A, B) ${OPERATOR}

layout(std430) buffer;

// Tensor declarations: the output (declared with "w"), then the input, weight
// and bias tensors (declared with "r"). Presumably "w"/"r" select write/read
// access — TODO confirm against layout_declare_tensor.
${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "kernel_in", DTYPE, STORAGE)}
${layout_declare_tensor(B, "r", "bias_in", DTYPE, STORAGE)}

// main() returns early for invocation IDs at or beyond out_limits, and reads
// in_sizes.x as the input length and in_sizes.z as the batch size.
${layout_declare_ubo(B, "ivec3", "out_limits")}
${layout_declare_ubo(B, "ivec4", "in_sizes")}

// Convolution parameters consumed by main().
${layout_declare_ubo(B,"int", "kernel_size", "int", "stride", "int", "padding", "int", "dilation", "int", "in_group_size", "int", "out_group_size")}

// Forwarded by main() to op() as its A and B arguments.
${layout_declare_ubo(B, "float", "out_min", "float", "out_max")}

#include "indexing_utils.h"

// Workgroup dimensions come from specialization constants 0, 1 and 2.
layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;

// One layout specialization constant per tensor. Each is passed through
// unhash_axis_map() to obtain the ivec4 axis map used when addressing that
// tensor.
${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);

${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);

${layout_declare_spec_const(C, "int", "kernel_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 kernel_axis_map = unhash_axis_map(kernel_layout);

${layout_declare_spec_const(C, "int", "bias_layout", "DEFAULT_LAYOUT")}
const lowp ivec4 bias_axis_map = unhash_axis_map(bias_layout);

// Notation used in the comments of this shader. Let us define
//
//   input = (N, in_C, in_L),
//   output = (N, out_C, out_L),
//   groups = G,
//   kernel = K,
//
// which results in shapes
//
//   weight = (out_C, in_C / G, K),
//   bias = (out_C,).
//
// Dispatch strategy: one shader invocation per output channel, i.e. out_C
// invocations in total. Each invocation slides the kernel along the length
// dimension for every batch, producing out_L * N results.
//
// The work could instead be split into out_L * out_C * ceil(N / 4)
// invocations that each compute a single result, but that performs worse.
void main() {
  const ivec3 lpos = ivec3(gl_GlobalInvocationID);

  // Nothing to do for invocations at or beyond the output limits.
  if (any(greaterThanEqual(lpos, out_limits))) {
    return;
  }

  // The output channel index is the only value that differs between shader
  // invocations; it selects both the bias texel and the weights.
  const int out_c = lpos.y;
  const VEC4_T bias_texel = load_texel_lpos(bias_in, ivec3(out_c, 0, 0), bias_axis_map);

  const int input_len = in_sizes.x;
  const int num_batches = in_sizes.z;

  // Input channels [group_begin, group_end) are the input group that
  // corresponds to this output channel's group.
  const int group_begin = (out_c / out_group_size) * in_group_size;
  const int group_end = group_begin + in_group_size;

  // Range of window start positions along the input length. It begins inside
  // the leading padding and ends once the last dilated kernel tap would fall
  // past the trailing padding.
  const int first_l = -padding;
  const int l_limit = input_len + padding - dilation * (kernel_size - 1);

  // The input/output tensors are channel-packed, which is along the batch
  // dimension, so four batch entries are read and written per texel.
  for (int n = 0; n < num_batches; n += 4) {
    const int batch_texel = n / 4;

    // out_l is the output length index being written; in_l is the start of
    // the matching input window. Both advance together.
    for (int out_l = 0, in_l = first_l; in_l < l_limit; ++out_l, in_l += stride) {
      VEC4_T acc = VEC4_T(0);

      for (int in_c = group_begin; in_c < group_end; ++in_c) {
        // Channel index relative to its group, used to address the weights.
        const int w_c = in_c % in_group_size;

        // The weight tensor is width-packed, which is along the length
        // dimension, so four kernel taps are fetched per texel. This reads
        // out-of-bound zeros; avoiding them would complicate the loop
        // conditions and result in worse performance.
        for (int k = 0; k < kernel_size; k += 4) {
          const VEC4_T w4 = load_texel_lpos(kernel_in, ivec3(k / 4, w_c, out_c), kernel_axis_map);

          // Position of the first of the four taps. The other three are
          // reached by stepping `dilation` along the axis that in_axis_map.x
          // selects.
          ivec3 tap_pos = lpos_to_pos(ivec3(in_l + k * dilation, in_c, batch_texel), in_axis_map);
          acc = fma(w4.xxxx, load_texel(t_in, tap_pos), acc);

          tap_pos[in_axis_map.x] += dilation;
          acc = fma(w4.yyyy, load_texel(t_in, tap_pos), acc);

          tap_pos[in_axis_map.x] += dilation;
          acc = fma(w4.zzzz, load_texel(t_in, tap_pos), acc);

          tap_pos[in_axis_map.x] += dilation;
          acc = fma(w4.wwww, load_texel(t_in, tap_pos), acc);
        }
      }

      // Add the bias, apply op() with the output bounds, and store the texel.
      write_texel_lpos(t_out, ivec3(out_l, out_c, batch_texel), op(acc + bias_texel.x, out_min, out_max), out_axis_map);
    }
  }
}