• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Copyright (c) Meta Platforms, Inc. and affiliates.
 * All rights reserved.
 *
 * This source code is licensed under the BSD-style license found in the
 * LICENSE file in the root directory of this source tree.
 */
8
// Preamble of a templated GLSL compute shader: version, macros, resource
// bindings, workgroup size and layout specialization constants used by main().
//
// NOTE(review): the leading digits on each line below look like line numbers
// fused in by a source-browser export; they are not GLSL and would have to be
// stripped before this file can be compiled — confirm against the upstream
// file.
//
// NOTE(review): the `${...}` placeholders are presumably expanded by a
// code-generation step before the shader is compiled — confirm against the
// build tooling.
9#version 450 core
10
// Precision qualifier, supplied by the template.
11#define PRECISION ${PRECISION}
12
// Texel vector type derived from the template's DTYPE; main() accumulates and
// loads values of this type.
13#define VEC4_T ${texel_type(DTYPE)}
14
// Post-processing applied to each result before it is written; main() calls it
// as op(sum + bias.x, out_min, out_max). The expression itself comes from the
// template's OPERATOR.
15#define op(X, A, B) ${OPERATOR}
16
// Default memory layout for buffer blocks declared after this point.
17layout(std430) buffer;
18
// Tensor resources. main() writes t_out and reads t_in, kernel_in and bias_in;
// the "w"/"r" arguments presumably select write/read access, and B is
// presumably a running binding index kept by the template, so the order of
// these declarations matters — TODO confirm.
19${layout_declare_tensor(B, "w", "t_out", DTYPE, STORAGE)}
20${layout_declare_tensor(B, "r", "t_in", DTYPE, STORAGE)}
21${layout_declare_tensor(B, "r", "kernel_in", DTYPE, STORAGE)}
22${layout_declare_tensor(B, "r", "bias_in", DTYPE, STORAGE)}
23
// out_limits: main() returns early when any component of the invocation ID is
// greater than or equal to the matching component of this vector.
// in_sizes: main() reads .x as the input length and .z as the batch size.
24${layout_declare_ubo(B, "ivec3", "out_limits")}
25${layout_declare_ubo(B, "ivec4", "in_sizes")}
26
// Convolution parameters, declared together in one layout_declare_ubo call.
// in_group_size / out_group_size are the channel counts per group on the
// input and output side, as used by main() to pick the input channel range.
27${layout_declare_ubo(B,"int", "kernel_size", "int", "stride", "int", "padding", "int", "dilation", "int", "in_group_size", "int", "out_group_size")}
28
// Bounds forwarded to op() as its second and third arguments.
29${layout_declare_ubo(B, "float", "out_min", "float", "out_max")}
30
// Presumably provides the helpers main() calls (load_texel, load_texel_lpos,
// write_texel_lpos, lpos_to_pos, unhash_axis_map) — verify in the header.
31#include "indexing_utils.h"
32
// Workgroup dimensions are taken from specialization constants 0, 1 and 2.
33layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in;
34
// One layout specialization constant per tensor, each decoded into an axis map
// that main() passes to the texel load/store helpers. C is presumably a
// running constant-id counter kept by the template — TODO confirm.
35${layout_declare_spec_const(C, "int", "out_layout", "DEFAULT_LAYOUT")}
36const lowp ivec4 out_axis_map = unhash_axis_map(out_layout);
37
38${layout_declare_spec_const(C, "int", "in_layout", "DEFAULT_LAYOUT")}
39const lowp ivec4 in_axis_map = unhash_axis_map(in_layout);
40
41${layout_declare_spec_const(C, "int", "kernel_layout", "DEFAULT_LAYOUT")}
42const lowp ivec4 kernel_axis_map = unhash_axis_map(kernel_layout);
43
44${layout_declare_spec_const(C, "int", "bias_layout", "DEFAULT_LAYOUT")}
45const lowp ivec4 bias_axis_map = unhash_axis_map(bias_layout);
46
// 1D convolution entry point.
//
// Notation, with G groups and a kernel of length K:
//
//   input  : (N, in_C, in_L)
//   output : (N, out_C, out_L)
//   weight : (out_C, in_C / G, K)
//   bias   : (out_C,)
//
// The work is split into out_C shader invocations. Each invocation owns one
// output channel and slides the kernel along the length dimension for every
// batch, so it produces out_L * N results.
//
// The same computation could be expressed as out_L * out_C * ceil(N / 4)
// invocations that each produce a single result, but that variant performs
// worse.
void main() {
  const ivec3 gid = ivec3(gl_GlobalInvocationID);
  if (any(greaterThanEqual(gid, out_limits))) {
    return;
  }

  // The output channel is the only quantity that differs between invocations.
  const int oc = gid.y;
  const VEC4_T bias_texel =
      load_texel_lpos(bias_in, ivec3(oc, 0, 0), bias_axis_map);

  const int len_in = in_sizes.x;
  const int num_batches = in_sizes.z;

  // First input channel of the group that feeds this output channel. The
  // channel loop below walks the in_group_size channels starting here.
  const int group_base = (oc / out_group_size) * in_group_size;

  // Input positions at which the first kernel tap is placed: start inside the
  // left padding and stop once the dilated kernel would run past the right
  // padding.
  const int first_l = -padding;
  const int limit_l = len_in + padding - dilation * (kernel_size - 1);

  // The input and output tensors are channel-packed, which here falls along
  // the batch dimension, so four batch entries are read/written per texel.
  for (int batch = 0; batch < num_batches; batch += 4) {
    const int batch_texel = batch / 4;

    // "ol" is the output length index being written; "il" is the matching
    // input length index of the first kernel tap.
    int ol = 0;
    int il = first_l;

    while (il < limit_l) {
      VEC4_T acc = VEC4_T(0);

      // "gc" is the channel index within the group (it addresses the weight),
      // "ic" is the absolute input channel (it addresses the input).
      for (int gc = 0; gc < in_group_size; ++gc) {
        const int ic = group_base + gc;

        // Taps are consumed four at a time because the weight tensor is
        // width-packed along the length dimension. The last step may read
        // out-of-bound zeros; avoiding them would complicate the loop
        // conditions, which results in worse performance.
        for (int tap = 0; tap < kernel_size; tap += 4) {
          const VEC4_T w = load_texel_lpos(
              kernel_in, ivec3(tap / 4, gc, oc), kernel_axis_map);

          ivec3 src = lpos_to_pos(
              ivec3(il + tap * dilation, ic, batch_texel), in_axis_map);
          acc = fma(w.xxxx, load_texel(t_in, src), acc);

          // Step one dilated tap along the input's length axis each time.
          src[in_axis_map.x] += dilation;
          acc = fma(w.yyyy, load_texel(t_in, src), acc);

          src[in_axis_map.x] += dilation;
          acc = fma(w.zzzz, load_texel(t_in, src), acc);

          src[in_axis_map.x] += dilation;
          acc = fma(w.wwww, load_texel(t_in, src), acc);
        }
      }

      write_texel_lpos(
          t_out,
          ivec3(ol, oc, batch_texel),
          op(acc + bias_texel.x, out_min, out_max),
          out_axis_map);

      il += stride;
      ++ol;
    }
  }
}
129