• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
/*
 * Computes a 4D tensor co-ordinate from a linearized index.
 *
 * Component i of the result is (idx / strides[i]) % sizes[i], evaluated
 * component-wise in unsigned integer arithmetic.
 *
 * Why integer `%` instead of `mod()`: GLSL's `mod` only has floating-point
 * overloads. Calling it on uvec4 operands silently converts them to vec4,
 * which rounds any value above 2^24 and yields wrong co-ordinates for large
 * tensors.
 *
 * NOTE(review): assumes every component of `strides` and `sizes` is non-zero;
 * integer division / modulo by zero is undefined in GLSL.
 */
uvec4 idx_to_coord(const uint idx, const uvec4 strides, const uvec4 sizes) {
  return (idx / strides) % sizes;
}

/*
 * Computes a linearized index from a 4D tensor co-ordinate.
 *
 * Returns sum_i coord[i] * strides[i], evaluated in unsigned integer
 * arithmetic (wrapping modulo 2^32 on overflow, as GLSL integers do).
 *
 * Why not `dot()`: GLSL's `dot` is float-only. Passing uvec4/ivec4 arguments
 * converts them to vec4, so any linear index above 2^24 would be rounded
 * before being converted back to an integer.
 */
uint coord_to_idx(const uvec4 coord, const uvec4 strides) {
  uvec4 offsets = coord * strides;
  return offsets.x + offsets.y + offsets.z + offsets.w;
}

/*
 * Rounds `v` up to the next multiple of 4 (a multiple of 4 is returned
 * unchanged), using truncating integer division.
 */
int align_up_4(int v) {
  return 4 * ((v + 3) / 4);
}

/*
 * Maps an {n, c, h, w} index into the channel-packed 3D layout.
 *
 * Returns ivec4(x, y, z, component):
 *   x         = w
 *   y         = h
 *   z         = n * ceil(C / 4) + c / 4   (C = sizes.y)
 *   component = c % 4
 * i.e. each z slice holds four consecutive channels, and each batch entry
 * occupies ceil(C / 4) slices.
 *
 * NOTE(review): assumes `sizes` uses the same {N, C, H, W} order as `nchw`
 * (only sizes.y is read here) - confirm against callers.
 */
ivec4 get_channel_packed_pos_from_index(ivec4 nchw, ivec4 sizes) {
  // Channel count rounded up to a multiple of 4, then divided by 4.
  int slices_per_batch = (sizes.y + 3) / 4;

  int x = nchw.w;
  int y = nchw.z;
  int z = nchw.x * slices_per_batch + nchw.y / 4;
  int component = nchw.y % 4;
  return ivec4(x, y, z, component);
}