# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
15""" dense sparse to densne matmul"""
from __future__ import absolute_import
from te import tik
from topi.cce import util
from mindspore.ops.op_info_register import DataType, TBERegOp, op_info_register


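# TBE registration for the DSDMatmul custom op: three required float16 inputs
# (local weights w1, global weights w2, values v) and one float16 output.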
dsd_matmul_info = TBERegOp('DSDMatmul') \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("dsdmatmul.so") \
    .compute_cost(10) \
    .kernel_name("dsd_matmul") \
    .partial_flag(True) \
    .input(0, "input_w1", False, "required", "all") \
    .input(1, "input_w2", False, "required", "all") \
    .input(2, "input_v", False, "required", "all") \
    .output(0, "output_y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .get_op_info()


@op_info_register(dsd_matmul_info)
def dsd_matmul(input_w1, input_w2, input_v, output_y={}, kernel_name='dsd_matmul'):
    """dense sparse to dense matmul"""
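    # Pick the TIK profile that matches the current Ascend product form (mini vs. cloud).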
    if util.get_product_version() == util.VERSION_MINI:
        tik_inst = tik.Tik(tik.Dprofile("v100", "mini"))
    else:
        tik_inst = tik.Tik(tik.Dprofile("v100", "cloud"))

    # shape is: (batch_size, head, block_num, head_size//16, block_size//16, 16, 16)
    input_w1_shape = input_w1.get('shape')
    # shape is: (batch_size, head, block_num, global_size//16, head_size//16, 16, 16)
    input_w2_shape = input_w2.get('shape')
    input_v_shape = input_v.get('shape')

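    # Recover the logical sizes from the 16x16 fractal shapes.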
    batch_size = input_w1_shape[0]
    head = input_w1_shape[1]
    block_num = input_w1_shape[2]
    block_size = input_w1_shape[4] * 16
    head_size = input_w1_shape[3] * 16
    global_size = input_w2_shape[3] * 16
    v_embedding = input_v_shape[1] * 16 // head
    seq_len = input_v_shape[0] * 16 // batch_size

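    # data_move burst lengths are expressed in 32-byte units; long sequences are
    # split into seq_len//512 rounds so that one slice of the global part fits in L0.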
    block_bite_size = 32
    cpt_time = seq_len//512

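    # Kernel inputs and output live in global memory; data is tiled into 16x16
    # fractals as required by the cube unit.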
    w1_gm = tik_inst.Tensor('float16', (batch_size, head, block_num, head_size //
                                        16, block_size//16, 16, 16), name='w1_gm', scope=tik.scope_gm)
    w2_gm = tik_inst.Tensor('float16', (batch_size, head, block_num, global_size //
                                        16, head_size//16, 16, 16), name='w2_gm', scope=tik.scope_gm)
    #
    v_gm = tik_inst.Tensor('float16', (batch_size*seq_len//16,
                                       head*v_embedding//16, 16, 16), name='v_gm', scope=tik.scope_gm)
    # zN
    output_gm = tik_inst.Tensor('float16', (batch_size, head, v_embedding // 16, seq_len//16, 16, 16), name='output_gm',
                                scope=tik.scope_gm)

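    # block_num=channel_num launches one AI core per iteration, i.e. one
    # (batch, head) channel per core.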
    channel_num = batch_size*head
    with tik_inst.for_range(0, channel_num, block_num=channel_num) as channel_idx:
        head_idx = channel_idx // batch_size
        bs_idx = channel_idx % batch_size
        output_l0c = tik_inst.Tensor("float32", (v_embedding // 16, block_size // 16, 16, 16), name='output_l0c',
                                     scope=tik.scope_cc)
        output_ub_32 = tik_inst.Tensor('float32', (v_embedding // 16, block_size // 16, 16, 16), name='output_ub_32',
                                       scope=tik.scope_ubuf)
        output_ub = tik_inst.Tensor('float16', (v_embedding // 16, block_size // 16, 16, 16), name='output_ub',
                                    scope=tik.scope_ubuf)
        # zZ
        w1_l1 = tik_inst.Tensor(
            'float16', (block_size//16, head_size//16, 16, 16), name='w1_l1', scope=tik.scope_cbuf)
        # nZ
        v_local_l1 = tik_inst.Tensor(
            'float16', (head_size//16, v_embedding//16, 16, 16), name='v_local_l1', scope=tik.scope_cbuf)
        # zZ
        w2_l1 = tik_inst.Tensor('float16', (head_size//16, global_size//(16*cpt_time), 16, 16),
                                name='w2_l1', scope=tik.scope_cbuf)
        # nZ
        # use same v_global
        v_global_l1 = tik_inst.Tensor('float16', (global_size//16, v_embedding//16, 16, 16),
                                      name='v_global_l1', scope=tik.scope_cbuf)
        # global part of v: each head gathers one 16-row fractal out of every group
        # of four (row global_idx within the group) into L1 once; the same
        # v_global_l1 tile is reused for every local block of this channel.
        global_idx = 3 - head_idx % 4
        tik_inst.data_move(v_global_l1[0, 0, 0, 0], v_gm[bs_idx * seq_len // 16 + global_idx,
                                                         head_idx * v_embedding // 16, 0, 0], 0, seq_len // (4 * 16),
                           16 * v_embedding * 2 // block_bite_size,
                           (4 * head * v_embedding * 16 - 16 * v_embedding) * 2 // block_bite_size, 0)
        # each block holds 64 rows; together the local and global parts produce a (1024, 128) zN output
        with tik_inst.for_range(0, block_num, thread_num=2) as w_idx:
            # global
            with tik_inst.new_stmt_scope():
                w2_l0a = tik_inst.Tensor('float16', (head_size//16, global_size//(cpt_time*16), 16, 16),
                                         name='w2_l0a', scope=tik.scope_ca)
                v_global_l0b = tik_inst.Tensor('float16', (global_size//(cpt_time*16), v_embedding//16, 16, 16),
                                               name='v_global_l0b', scope=tik.scope_cb)
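                # Accumulate w2 @ v_global into output_l0c, one global_size//cpt_time
                # slice per round: the first round initializes L0C, later rounds accumulate.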
                with tik_inst.for_range(0, cpt_time) as cpt_idx:
                    with tik_inst.for_range(0, head_size//16) as brick_i:
                        tik_inst.data_move(w2_l1[brick_i, 0, 0, 0],
                                           w2_gm[bs_idx, head_idx, w_idx, cpt_idx *
                                                 global_size//(16*cpt_time), brick_i, 0, 0], 0,
                                           global_size//(16*cpt_time), 16 * 16*2//block_bite_size,
                                           (block_size//16-1)*16*16*2//block_bite_size, 0)
                    tik_inst.load2dv1(
                        w2_l0a[0, 0, 0, 0], w2_l1[0, 0, 0, 0], 0, block_size*global_size//(cpt_time*16*16), 1, 0)

                    tik_inst.load2dv1(v_global_l0b[0, 0, 0, 0], v_global_l1[cpt_idx*global_size//(
                        16*cpt_time), 0, 0, 0], 0, global_size*v_embedding//(16*16*cpt_time), 1, 0)

                    with tik_inst.if_scope(cpt_idx == 0):
                        tik_inst.mmad(output_l0c, w2_l0a, v_global_l0b,
                                      block_size, global_size//cpt_time, v_embedding, 0)
                    with tik_inst.else_scope():
                        tik_inst.mmad(output_l0c, w2_l0a, v_global_l0b,
                                      block_size, global_size//cpt_time, v_embedding, 1)
            # local
            with tik_inst.new_stmt_scope():
                w1_l0a = tik_inst.Tensor('float16', (block_size//16, head_size//16, 16, 16),
                                         name='w1_l0a', scope=tik.scope_ca)
                v_local_l0b = tik_inst.Tensor('float16', (head_size//16, v_embedding//16, 16, 16),
                                              name='v_local_l0b', scope=tik.scope_cb)
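                # Copy this block's slice of v from GM into L1, then load it to L0B.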
                tik_inst.data_move(v_local_l1[0, 0, 0, 0],
                                   v_gm[bs_idx * seq_len//16 + w_idx * 4, head_idx *
                                        v_embedding//16, 0, 0], 0, block_size//16,
                                   16 * v_embedding * 2 // block_bite_size,
                                   16 * (head-1)*v_embedding*2//block_bite_size, 0)
                tik_inst.load2dv1(v_local_l0b[0, 0, 0, 0], v_local_l1[0, 0, 0, 0], 0,
                                  head_size*v_embedding//(16*16), 1, 0)
                # local weights w1: copy the block's row fractals into L1, then to L0A
                with tik_inst.for_range(0, block_size // 16) as brick_i:
                    tik_inst.data_move(w1_l1[brick_i, 0, 0, 0], w1_gm[bs_idx, head_idx, w_idx, 0, brick_i, 0, 0], 0,
                                       head_size // 16, (16*16*2)//block_bite_size,
                                       (block_size // 16 - 1) * 16 * 16 * 2 // block_bite_size, 0)
                tik_inst.load2dv1(w1_l0a[0, 0, 0, 0], w1_l1[0, 0, 0, 0], 0, block_size*head_size//(16*16), 1, 0)
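                # Accumulate w1 @ v_local onto the global result in L0C, convert the
                # fp32 accumulator to fp16 in UB, and write this block back to GM.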
                tik_inst.mmad(output_l0c, w1_l0a, v_local_l0b,
                              block_size, head_size, v_embedding, 1)
                tik_inst.data_move(output_ub_32[0, 0, 0, 0], output_l0c[0, 0, 0, 0], 0,
                                   1, block_size * v_embedding * 4 // 1024, 0, 0)
                tik_inst.vconv(64, '', output_ub[0, 0, 0, 0], output_ub_32[0, 0, 0, 0],
                               v_embedding * block_size//64, 1, 1, 4, 8)
                tik_inst.data_move(output_gm[bs_idx, head_idx, 0, w_idx*(block_size//16), 0, 0], output_ub[0, 0, 0, 0],
                                   0, v_embedding//16, 16*block_size*2//block_bite_size, 0,
                                   (seq_len - block_size)*16*2//block_bite_size)
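    # Compile the kernel with the three GM inputs and the single GM output.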
    tik_inst.BuildCCE(kernel_name=kernel_name, inputs=[w1_gm, w2_gm, v_gm],
                      outputs=[output_gm])
    return tik_inst