# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""dense sparse to dense matmul"""
from __future__ import absolute_import
from te import tik
from topi.cce import util
from mindspore.ops.op_info_register import DataType, TBERegOp, op_info_register


dsd_matmul_info = TBERegOp('DSDMatmul') \
    .fusion_type("OPAQUE") \
    .async_flag(False) \
    .binfile_name("dsdmatmul.so") \
    .compute_cost(10) \
    .kernel_name("dsd_matmul") \
    .partial_flag(True) \
    .input(0, "input_w1", False, "required", "all") \
    .input(1, "input_w2", False, "required", "all") \
    .input(2, "input_v", False, "required", "all") \
    .output(0, "output_y", False, "required", "all") \
    .dtype_format(DataType.F16_Default, DataType.F16_Default, DataType.F16_Default, DataType.F16_Default) \
    .get_op_info()


@op_info_register(dsd_matmul_info)
def dsd_matmul(input_w1, input_w2, input_v, output_y={}, kernel_name='dsd_matmul'):
    """dense sparse to dense matmul"""
    if util.get_product_version() == util.VERSION_MINI:
        tik_inst = tik.Tik(tik.Dprofile("v100", "mini"))
    else:
        tik_inst = tik.Tik(tik.Dprofile("v100", "cloud"))

    # shape is: (batch_size, head, block_num, head_size//16, block_size//16, 16, 16)
    input_w1_shape = input_w1.get('shape')
    # shape is: (batch_size, head, block_num, global_size//16, head_size//16, 16, 16)
    input_w2_shape = input_w2.get('shape')
    # shape is: (batch_size*seq_len//16, head*v_embedding//16, 16, 16)
    input_v_shape = input_v.get('shape')

    batch_size = input_w1_shape[0]
    head = input_w1_shape[1]
    block_num = input_w1_shape[2]
    block_size = input_w1_shape[4] * 16
    head_size = input_w1_shape[3] * 16
    global_size = input_w2_shape[3] * 16
    v_embedding = input_v_shape[1] * 16 // head
    seq_len = input_v_shape[0] * 16 // batch_size

    block_byte_size = 32
    cpt_time = seq_len // 512
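
    # Illustrative example of the derived sizes (hypothetical values, not taken from this file):
    # with batch_size=2, head=4, seq_len=1024, block_size=head_size=64, global_size=256 and
    # v_embedding=128, the inputs would be
    #   block_num = 16, cpt_time = 1024 // 512 = 2,
    #   w1_gm:     (2, 4, 16, 4, 4, 16, 16)
    #   w2_gm:     (2, 4, 16, 16, 4, 16, 16)
    #   v_gm:      (128, 32, 16, 16)
    #   output_gm: (2, 4, 8, 64, 16, 16)
    # Per block, the mmad calls below accumulate (inferred from the code, not a documented formula):
    #   out[block] = w2[block] @ v_global + w1[block] @ v_local[block]
    # where v_global is one 16-row slab out of every 64 rows of v, shared by all blocks
    # of a (batch, head) channel.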

    w1_gm = tik_inst.Tensor('float16', (batch_size, head, block_num, head_size//16, block_size//16, 16, 16),
                            name='w1_gm', scope=tik.scope_gm)
    w2_gm = tik_inst.Tensor('float16', (batch_size, head, block_num, global_size//16, head_size//16, 16, 16),
                            name='w2_gm', scope=tik.scope_gm)
    v_gm = tik_inst.Tensor('float16', (batch_size*seq_len//16, head*v_embedding//16, 16, 16),
                           name='v_gm', scope=tik.scope_gm)
    # zN
    output_gm = tik_inst.Tensor('float16', (batch_size, head, v_embedding//16, seq_len//16, 16, 16),
                                name='output_gm', scope=tik.scope_gm)

    channel_num = batch_size*head
    with tik_inst.for_range(0, channel_num, block_num=channel_num) as channel_idx:
        head_idx = channel_idx // batch_size
        bs_idx = channel_idx % batch_size
        output_l0c = tik_inst.Tensor('float32', (v_embedding//16, block_size//16, 16, 16),
                                     name='output_l0c', scope=tik.scope_cc)
        output_ub_32 = tik_inst.Tensor('float32', (v_embedding//16, block_size//16, 16, 16),
                                       name='output_ub_32', scope=tik.scope_ubuf)
        output_ub = tik_inst.Tensor('float16', (v_embedding//16, block_size//16, 16, 16),
                                    name='output_ub', scope=tik.scope_ubuf)
        # zZ
        w1_l1 = tik_inst.Tensor(
            'float16', (block_size//16, head_size//16, 16, 16), name='w1_l1', scope=tik.scope_cbuf)
        # nZ
        v_local_l1 = tik_inst.Tensor(
            'float16', (head_size//16, v_embedding//16, 16, 16), name='v_local_l1', scope=tik.scope_cbuf)
        # zZ
        w2_l1 = tik_inst.Tensor('float16', (head_size//16, global_size//(16*cpt_time), 16, 16),
                                name='w2_l1', scope=tik.scope_cbuf)
        # nZ
        # all blocks of this channel share the same v_global
        v_global_l1 = tik_inst.Tensor('float16', (global_size//16, v_embedding//16, 16, 16),
                                      name='v_global_l1', scope=tik.scope_cbuf)
        # global v
        global_idx = 3 - head_idx % 4
        tik_inst.data_move(v_global_l1[0, 0, 0, 0], v_gm[bs_idx * seq_len // 16 + global_idx,
                                                         head_idx * v_embedding // 16, 0, 0], 0, seq_len // (4 * 16),
                           16 * v_embedding * 2 // block_byte_size,
                           (4 * head * v_embedding * 16 - 16 * v_embedding) * 2 // block_byte_size, 0)
        # each block has size 64; the combined local and global result is (1024, 128) in zN layout
        with tik_inst.for_range(0, block_num, thread_num=2) as w_idx:
            # global
            with tik_inst.new_stmt_scope():
                w2_l0a = tik_inst.Tensor('float16', (head_size//16, global_size//(cpt_time*16), 16, 16),
                                         name='w2_l0a', scope=tik.scope_ca)
                v_global_l0b = tik_inst.Tensor('float16', (global_size//(cpt_time*16), v_embedding//16, 16, 16),
                                               name='v_global_l0b', scope=tik.scope_cb)
                with tik_inst.for_range(0, cpt_time) as cpt_idx:
                    with tik_inst.for_range(0, head_size//16) as brick_i:
                        tik_inst.data_move(w2_l1[brick_i, 0, 0, 0],
                                           w2_gm[bs_idx, head_idx, w_idx, cpt_idx * global_size//(16*cpt_time),
                                                 brick_i, 0, 0], 0,
                                           global_size//(16*cpt_time), 16*16*2//block_byte_size,
                                           (block_size//16-1)*16*16*2//block_byte_size, 0)
                    tik_inst.load2dv1(w2_l0a[0, 0, 0, 0], w2_l1[0, 0, 0, 0], 0,
                                      block_size*global_size//(cpt_time*16*16), 1, 0)

                    tik_inst.load2dv1(v_global_l0b[0, 0, 0, 0],
                                      v_global_l1[cpt_idx*global_size//(16*cpt_time), 0, 0, 0], 0,
                                      global_size*v_embedding//(16*16*cpt_time), 1, 0)

                    with tik_inst.if_scope(cpt_idx == 0):
                        tik_inst.mmad(output_l0c, w2_l0a, v_global_l0b,
                                      block_size, global_size//cpt_time, v_embedding, 0)
                    with tik_inst.else_scope():
                        tik_inst.mmad(output_l0c, w2_l0a, v_global_l0b,
                                      block_size, global_size//cpt_time, v_embedding, 1)
            # local
            with tik_inst.new_stmt_scope():
                w1_l0a = tik_inst.Tensor('float16', (block_size//16, head_size//16, 16, 16),
                                         name='w1_l0a', scope=tik.scope_ca)
                v_local_l0b = tik_inst.Tensor('float16', (head_size//16, v_embedding//16, 16, 16),
                                              name='v_local_l0b', scope=tik.scope_cb)
                tik_inst.data_move(v_local_l1[0, 0, 0, 0],
                                   v_gm[bs_idx * seq_len//16 + w_idx * 4, head_idx * v_embedding//16, 0, 0], 0,
                                   block_size//16, 16 * v_embedding * 2 // block_byte_size,
                                   16 * (head-1)*v_embedding*2//block_byte_size, 0)
                tik_inst.load2dv1(v_local_l0b[0, 0, 0, 0], v_local_l1[0, 0, 0, 0], 0,
                                  head_size*v_embedding//(16*16), 1, 0)
                # w1
                with tik_inst.for_range(0, block_size // 16) as brick_i:
                    tik_inst.data_move(w1_l1[brick_i, 0, 0, 0], w1_gm[bs_idx, head_idx, w_idx, 0, brick_i, 0, 0], 0,
                                       head_size // 16, (16*16*2)//block_byte_size,
                                       (block_size // 16 - 1) * 16 * 16 * 2 // block_byte_size, 0)
                tik_inst.load2dv1(w1_l0a[0, 0, 0, 0], w1_l1[0, 0, 0, 0], 0, block_size*head_size//(16*16), 1, 0)
                tik_inst.mmad(output_l0c, w1_l0a, v_local_l0b,
                              block_size, head_size, v_embedding, 1)
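
            # copy the accumulated float32 result for this block from L0C to UB,
            # convert it to float16, and write it into the zN-layout output in GM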
            tik_inst.data_move(output_ub_32[0, 0, 0, 0], output_l0c[0, 0, 0, 0], 0,
                               1, block_size * v_embedding * 4 // 1024, 0, 0)
            tik_inst.vconv(64, '', output_ub[0, 0, 0, 0], output_ub_32[0, 0, 0, 0],
                           v_embedding * block_size//64, 1, 1, 4, 8)
            tik_inst.data_move(output_gm[bs_idx, head_idx, 0, w_idx*(block_size//16), 0, 0], output_ub[0, 0, 0, 0],
                               0, v_embedding//16, 16*block_size*2//block_byte_size, 0,
                               (seq_len - block_size)*16*2//block_byte_size)
    tik_inst.BuildCCE(kernel_name=kernel_name, inputs=[w1_gm, w2_gm, v_gm],
                      outputs=[output_gm])
    return tik_inst