1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15import te.lang.cce
16from te import tvm
17from te.platform import CUBE_MKN
18from topi import generic
19from topi.cce import util
20from topi.cce.util import is_v200_version
21
# pylint: disable=R0912,R0913,R0914,R0915,E1101
# the dim of shape in conv must be 4
# PAD_SHAPE_DIM: required length of a list-form padding argument,
# i.e. [pad_before, pad_after] for one spatial axis (see the len()
# checks on padh/padw in conv_layer_cce_para_check).
PAD_SHAPE_DIM = 2

# Alias for the NoneType class, used in @util.check_input_type tuples
# (None itself is a value, not a type, so it cannot appear there directly).
NONETYPE = type(None)
27
28
@util.check_input_type((list, tuple), (list, tuple), str, str, str, (list, int), (list, int),
                       int, int, (list, tuple), (list, tuple),
                       str, str, str,
                       str, str, str,
                       str, bool, str)
def conv_layer_cce_para_check(shape_in, shape_w, in_dtype, w_dtype, res_dtype, padh, padw,
                              strideh, stridew, quantize_config, scale_sqrt,
                              scale_q_dtype, offset_q_dtype, scale_dq_dtype,
                              scale_rq_dtype, offset_rq_dtype, offset_w_dtype,
                              offset_pad_dtype, bias, kernel_name):
    """Validate every parameter of a conv_layer_cce invocation.

    Checks kernel name, all dtypes, the quantize/scale-sqrt configuration,
    the pad arguments, and finally the input/weight shapes (delegated to
    te.lang.cce.check_conv_shape, which may normalize them).

    Returns
    -------
    (shape_in, shape_w) : the shapes as normalized by check_conv_shape.

    Raises
    ------
    RuntimeError on any invalid combination.
    """
    # conv shape check
    util.check_kernel_name(kernel_name)

    # The decorator accepts both list and tuple for these two, but the
    # membership/equality tests below compare against list literals, so a
    # legal tuple such as (1, 0, 0) would be wrongly rejected. Normalize.
    quantize_config = list(quantize_config)
    scale_sqrt = list(scale_sqrt)

    # conv data type check
    util.check_dtype_rule(in_dtype, ['float16', 'int8', 'uint8'])
    util.check_dtype_rule(w_dtype, ['float16', 'int8', 'uint8'])
    res_dtype_list = ['float16', 'int8', 'uint8']
    if is_v200_version():
        # v200 additionally supports raw int32 accumulator output
        res_dtype_list.append('int32')
    util.check_dtype_rule(res_dtype, res_dtype_list)
    util.check_dtype_rule(scale_q_dtype, ['float16'])
    util.check_dtype_rule(offset_q_dtype, ['float16'])
    util.check_dtype_rule(scale_dq_dtype, ['float16'])
    util.check_dtype_rule(scale_rq_dtype, ['float16'])
    util.check_dtype_rule(offset_rq_dtype, ['float16'])
    util.check_dtype_rule(offset_w_dtype, ['int32'])
    util.check_dtype_rule(offset_pad_dtype, ['uint8'])

    if not isinstance(bias, bool):
        raise RuntimeError("bias dtype should be bool.")

    # quantize switch off: dtypes are fixed per platform version
    if quantize_config[0] == 0:
        if is_v200_version():
            util.check_dtype_rule(in_dtype, ('int8',))
            util.check_dtype_rule(w_dtype, ('int8',))
            util.check_dtype_rule(res_dtype, ('int32',))
        else:
            util.check_dtype_rule(in_dtype, ['float16'])
            util.check_dtype_rule(w_dtype, ['float16'])
            util.check_dtype_rule(res_dtype, ['float16'])

    # quantize switch on: weight must be int8; in/out depend on the algorithm
    if quantize_config[0] == 1:
        util.check_dtype_rule(w_dtype, ['int8'])
        if quantize_config[1] == 0:
            # non-offset quantization
            util.check_dtype_rule(in_dtype, ['int8', 'float16'])
            util.check_dtype_rule(res_dtype, ['int8', 'float16'])
        elif quantize_config[1] == 1:
            # half-offset quantization
            util.check_dtype_rule(in_dtype, ['uint8', 'float16'])
            util.check_dtype_rule(res_dtype, ['uint8', 'float16'])
        elif quantize_config[1] == 2:
            raise RuntimeError("All Offset mode quantize not support.")
        else:
            raise RuntimeError("Invalid quantize algorithm.")

    # quantize switch on
    if quantize_config[0] == 1:
        # quantize -> DeQuantize dataflow
        if in_dtype == 'float16' and w_dtype == 'int8' and res_dtype == 'float16':
            pass
        # DeQuantize dataflow
        elif (in_dtype in ['int8', 'uint8'] and w_dtype == 'int8' and
              res_dtype == 'float16'):
            pass
        # quantize -> ReQuantize dataflow
        elif (in_dtype == 'float16' and w_dtype == 'int8' and res_dtype in
              ['int8', 'uint8']):
            pass
        # ReQuantize dataflow
        elif (in_dtype in ['int8', 'uint8'] and w_dtype == 'int8' and res_dtype in
              ['int8', 'uint8']):
            pass
        else:
            raise RuntimeError("Not support in/out data type for quantize.")

        if quantize_config not in ([1, 0, 0], [1, 1, 0], [1, 0, 1], [1, 1, 1]):
            raise RuntimeError("Invalid Quantize Config.")

        if scale_sqrt not in ([0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1],
                              [1, 0, 1], [0, 1, 1], [1, 1, 1]):
            raise RuntimeError("Invalid Quantize Config.")

    # quantize switch off
    elif quantize_config[0] == 0:
        if quantize_config != [0, 0, 0]:
            raise RuntimeError("Invalid Quantize Config.")
        if scale_sqrt != [0, 0, 0]:
            raise RuntimeError("Invalid Quantize Config.")
    else:
        raise RuntimeError("Invalid Quantize Config.")

    def _pad_pair(pad, arg_name):
        # A list is explicit [before, after]; a scalar means symmetric padding.
        if isinstance(pad, list):
            if len(pad) != PAD_SHAPE_DIM:
                raise RuntimeError("Dimension must be %d when %s is a list."
                                   % (PAD_SHAPE_DIM, arg_name))
            return pad[0], pad[1]
        return pad, pad

    pad_top, pad_bottom = _pad_pair(padh, "padh")
    pad_left, pad_right = _pad_pair(padw, "padw")

    shape_in, shape_w = te.lang.cce.check_conv_shape(shape_in, shape_w, pad_top, pad_bottom,
                                                     pad_left, pad_right, strideh,
                                                     stridew, in_dtype, w_dtype, res_dtype)

    return shape_in, shape_w
142
143
@util.check_input_type((list, tuple), (list, tuple), str, str, str, \
                       (list, int), (list, int), int, int,
                       (list, NONETYPE), (list, NONETYPE),
                       str, str, str,
                       str, str, str, str,
                       bool, str, bool, bool)
def conv_layer_cce(shape_in, shape_w, in_dtype, w_dtype, res_dtype, padh, padw, strideh, stridew,
                   quantize_config=None, scale_sqrt=None,
                   scale_q_dtype='float16', offset_q_dtype='float16', scale_dq_dtype='float16',
                   scale_rq_dtype='float16', offset_rq_dtype='float16', offset_w_dtype='int32',
                   offset_pad_dtype='uint8', bias=False, kernel_name="cce_conv", need_build=False,
                   need_print=False):
    """
    Build a CCE convolution kernel, optionally with fused quantize /
    dequantize / requantize stages.

    Parameters
    ----------
    shape_in : shape of data_in

    shape_w : shape of filter

    in_dtype : the feature map data type

    w_dtype : the weight data type

    res_dtype : the result data type

    padh: the padding shape in H

    padw: the padding shape in weight

    strideh: the stride value in H

    stridew: the stride value in weight

    quantize_config: quantize config table, default [0, 0, 0]
    quantize_config[0] - quantize function switch
                        0: quantize off
                        1: quantize on
    quantize_config[1] - quantize_algorithm
                        0: non offset
                        1: half offset
                        2: all offset ( Not supported now )
    quantize_config[2] - QuantizeScaleType (for Dequantize/Requantize, quantize always scalar)
                        0: scalar
                        1: vector

    scale_sqrt: scale mode
    scale_sqrt[0] - Quantize scale mode
                   0: non sqrt
                   1: sqrt
    scale_sqrt[1] - DeQuantize scale mode
                   0: non sqrt
                   1: sqrt
    scale_sqrt[2] - ReQuantize scale mode
                   0: non sqrt
                   1: sqrt

    scale_q_dtype: Quantize scale data type, default 'float16'

    offset_q_dtype: Quantize offset data type, default 'float16'

    scale_dq_dtype: DeQuantize scale data type, default 'float16'

    scale_rq_dtype: ReQuantize scale data type, default 'float16'

    offset_rq_dtype: ReQuantize offset data type, default 'float16'

    offset_w_dtype: weight offset data type, default 'int32'

    offset_pad_dtype: Quantize Cube offset data type, default 'uint8'

    bias: the tag for bias or not

    kernel_name : cce kernel name, default value is "cce_conv"

    need_build : if need to build CCEC kernel, default value is False

    need_print : if need to print the ir, default value is False

    Returns
    -------
    wrapped_tensor

    """
    # for pylint, otherwise "Dangerous default value [] as argument"
    if quantize_config is None:
        quantize_config = [0, 0, 0]
    if scale_sqrt is None:
        scale_sqrt = [0, 0, 0]

    in_dtype = in_dtype.lower()
    w_dtype = w_dtype.lower()
    res_dtype = res_dtype.lower()
    scale_q_dtype = scale_q_dtype.lower()
    offset_q_dtype = offset_q_dtype.lower()
    scale_dq_dtype = scale_dq_dtype.lower()
    scale_rq_dtype = scale_rq_dtype.lower()
    offset_rq_dtype = offset_rq_dtype.lower()
    offset_w_dtype = offset_w_dtype.lower()
    offset_pad_dtype = offset_pad_dtype.lower()

    # int8 weights accumulate in int32; everything else accumulates in float32
    mad_dtype = 'int32' if w_dtype == 'int8' else 'float32'

    shape_in = list(shape_in)
    shape_w = list(shape_w)

    shape_in, shape_w = conv_layer_cce_para_check(shape_in, shape_w, in_dtype, w_dtype, res_dtype,
                                                  padh, padw, strideh, stridew,
                                                  quantize_config, scale_sqrt,
                                                  scale_q_dtype, offset_q_dtype, scale_dq_dtype,
                                                  scale_rq_dtype, offset_rq_dtype, offset_w_dtype,
                                                  offset_pad_dtype, bias, kernel_name)

    # quantize switch on: derive the dataflow from the dtypes.
    #   float16 input  -> a Quantize stage is needed in front of the cube
    #   float16 output -> DeQuantize after the cube, otherwise ReQuantize
    # (Equivalent to the explicit four-way dtype pattern match: every legal
    # combination has int8 weights, fp16/int8/uint8 input and output.)
    if quantize_config[0] == 1:
        quantize_turn_on = True
        if (w_dtype != 'int8' or in_dtype not in ('float16', 'int8', 'uint8')
                or res_dtype not in ('float16', 'int8', 'uint8')):
            raise RuntimeError("Not support in/out data type for quantize.")
        is_quantize = in_dtype == 'float16'
        is_dequantize = res_dtype == 'float16'
        is_requantize = not is_dequantize

    # quantize switch off
    elif quantize_config[0] == 0:
        quantize_turn_on = False
        is_quantize = False
        is_dequantize = False
        is_requantize = False

        if quantize_config != [0, 0, 0]:
            raise RuntimeError("Invalid Quantize Config.")
        if scale_sqrt != [0, 0, 0]:
            raise RuntimeError("Invalid Quantize Config.")
    else:
        raise RuntimeError("Invalid Quantize Config.")

    batch_size = shape_in[0]
    in_channel = shape_in[1]
    feature_map_h = shape_in[2]
    feature_map_w = shape_in[3]
    block_size_k = CUBE_MKN[in_dtype]['mac'][1]
    # NC1HWC0 layout: channels tiled into ceil(C / C0) blocks of C0 lanes
    fmap_shape_nc1hwc0 = (batch_size, (in_channel + block_size_k - 1) // block_size_k,
                          feature_map_h, feature_map_w, block_size_k)

    out_channel = shape_w[0]
    in_channel_weight = shape_w[1]
    filter_h = shape_w[2]
    filter_w = shape_w[3]
    block_size_k = CUBE_MKN[w_dtype]['mac'][1]
    block_size_n = CUBE_MKN[w_dtype]['mac'][2]
    # FRACTAL_Z weight layout: (C1*kh*kw, Cout1, Cout0, C0)
    filter_shape_frac_z = (in_channel_weight * filter_h * filter_w // block_size_k,
                           out_channel // block_size_n, block_size_n, block_size_k)

    with tvm.target.cce():
        data = tvm.placeholder(
            fmap_shape_nc1hwc0, name='Fmap', dtype=in_dtype)
        weight = tvm.placeholder(
            filter_shape_frac_z, name='Filter', dtype=w_dtype)
        bias_tensor = None
        scale_q = None
        scale_dq = None
        scale_rq = None
        offset_pad = None
        offset_rq = None
        offset_q = None
        scale_drq = None

        # bias or fusion_bias (half-offset quantization always carries one)
        if bias or (quantize_config[1] == 1 and quantize_turn_on):
            bias_tensor = tvm.placeholder(
                (out_channel,), name='bias_tensor',
                dtype="int32" if quantize_turn_on else res_dtype)

        quantize_algorithm = quantize_config[1]

        # declare the extra quantization inputs this dataflow needs
        if quantize_turn_on:
            if is_quantize:
                scale_q = tvm.placeholder(
                    (CUBE_MKN[scale_q_dtype]['mac'][1],), name='scaleQ', dtype=scale_q_dtype)
                if quantize_algorithm == 1:
                    offset_q = tvm.placeholder(
                        (CUBE_MKN[offset_q_dtype]['mac'][1],), name='offsetQ', dtype=offset_q_dtype)

            if is_dequantize:
                # scalar scale occupies one cube lane; vector scale is per out-channel
                scale_dq_shape = (CUBE_MKN[scale_dq_dtype]['mac'][1],) if quantize_config[2] == 0 \
                    else (out_channel,)
                scale_dq = tvm.placeholder(
                    scale_dq_shape, name='scaleDq', dtype=scale_dq_dtype)

            if is_requantize:
                scale_rq_shape = (CUBE_MKN[scale_rq_dtype]['mac'][1],) if quantize_config[2] == 0 \
                    else (out_channel,)
                scale_rq = tvm.placeholder(
                    scale_rq_shape, name='scaleRq', dtype=scale_rq_dtype)
                if quantize_algorithm == 1:
                    offset_rq = tvm.placeholder(
                        (CUBE_MKN[offset_rq_dtype]['mac'][1],), name='offsetRq',
                        dtype=offset_rq_dtype)

            # need offset_pad , for half offset
            if quantize_algorithm == 1:
                offset_pad = tvm.placeholder(
                    (CUBE_MKN[offset_pad_dtype]['mac'][1],), name='offset_pad',
                    dtype=offset_pad_dtype)

            # exactly one of DeQuantize / ReQuantize is active per dataflow
            scale_drq = scale_dq if is_dequantize else scale_rq

        # The conv op invocation is identical for every dataflow (it was
        # previously duplicated five times verbatim); only the tensor_list
        # assembled below differs per branch.
        conv_res = te.lang.cce.conv(
            data, weight, {"bias_tensor": bias_tensor,
                           "scale_q": scale_q,
                           "offset_q": offset_q,
                           "scale_drq": scale_drq,
                           "offset_pad": offset_pad,
                           "offset_rq": offset_rq,
                           "quantize_config": quantize_config,
                           "is_quantize": is_quantize,
                           "is_dequantize": is_dequantize,
                           "is_requantize": is_requantize,
                           "scale_sqrt": scale_sqrt,
                           "pad_h": padh, "pad_w": padw,
                           "stride_h": strideh, "stride_w": stridew,
                           "filter_h": filter_h, "filter_w": filter_w,
                           "res_dtype": res_dtype, "mad_dtype": mad_dtype},
            dsl_flag=False)

        # Assemble the externally-bound tensors for this dataflow.
        if not quantize_turn_on:
            tensor_list = [data, weight]
            if bias:
                tensor_list.append(bias_tensor)
        elif quantize_algorithm == 0:
            # non offset
            tensor_list = [data, weight]
            if bias:
                tensor_list.append(bias_tensor)
            if is_quantize:
                tensor_list.append(scale_q)
            tensor_list.append(scale_drq)
        else:
            # half offset: fusion bias_tensor is always an input
            tensor_list = [data, weight, bias_tensor]
            if is_quantize:
                tensor_list.extend([scale_q, offset_q])
            tensor_list.append(scale_drq)
            if is_requantize:
                tensor_list.append(offset_rq)
            tensor_list.append(offset_pad)
        tensor_list.append(conv_res)

        sch = generic.auto_schedule(conv_res)

    config = {
        "print_ir": need_print,
        "need_build": need_build,
        "name": kernel_name,
        "tensor_list": tensor_list
    }

    te.lang.cce.cce_build_code(sch, config)
520