# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""GraphKernel cost model"""

class Utils:
    """Model utils"""

    def __init__(self):
        pass

    @staticmethod
    def get_attr_type(attr):
        """Return the attr-type string used in the dumped graph json.

        Args:
            attr: attribute value; bool, str, int, float, or a non-empty
                list/tuple whose first element is int or str.

        Returns:
            str: one of 'bool', 'str', 'int', 'float', 'listInt', 'listStr'.

        Raises:
            ValueError: if attr is an empty list/tuple, or of an
                unsupported type.
        """
        # bool must be checked before int: bool is a subclass of int.
        if isinstance(attr, bool):
            return 'bool'
        if isinstance(attr, str):
            return 'str'
        if isinstance(attr, int):
            return 'int'
        if isinstance(attr, float):
            # Bug fix: previously returned 'bool' for float attrs.
            return 'float'
        if isinstance(attr, (list, tuple)):
            if not attr:
                raise ValueError("Length of attr is 0")
            # Element type is judged from the first element only.
            if isinstance(attr[0], int):
                return 'listInt'
            if isinstance(attr[0], str):
                return 'listStr'
        raise ValueError("Unknown type of attr: {}".format(attr))
43
44
class DataFormat:
    """Tensor layout-format identifier strings.

    Pure constant namespace; instances carry no state.
    """
    # Generic formats.
    DEFAULT = "DefaultFormat"
    NC1KHKWHWC0 = "NC1KHKWHWC0"
    ND = "ND"
    # Common 4-D layouts.
    NCHW = "NCHW"
    NHWC = "NHWC"
    HWCN = "HWCN"
    # Ascend-specific fractal/5-D layouts.
    NC1HWC0 = "NC1HWC0"
    FRAC_Z = "FracZ"
    FRAC_NZ = "FRACTAL_NZ"
    C1HWNCOC0 = "C1HWNCoC0"
    NC1HWC0_C04 = "NC1HWC0_C04"
    FRACTAL_Z_C04 = "FRACTAL_Z_C04"
    # 5-D layout for 3-D ops.
    NDHWC = "NDHWC"

    def __init__(self):
        pass
63
64
class DataType:
    """Tensor element-type identifier strings.

    Pure constant namespace; instances carry no state. The numeric
    suffix is the bit width (see PrimLib.dtype_bytes).
    """
    FLOAT = "float"
    FLOAT16 = "float16"
    FLOAT32 = "float32"
    FLOAT64 = "float64"
    INT = "int"
    INT8 = "int8"
    INT16 = "int16"
    INT32 = "int32"
    INT64 = "int64"
    UINT = "uint"
    UINT8 = "uint8"
    UINT16 = "uint16"
    UINT32 = "uint32"
    UINT64 = "uint64"
    BOOL = "bool"

    def __init__(self):
        pass
85
86
class PrimLib:
    """Primitive library: per-op iteration types and input/output axis relations."""

    # Iteration-type codes; also indices into Prim.default_relation_func.
    UNKNOWN = 0
    RESHAPE = 1
    ELEMWISE = 2
    BROADCAST = 3
    REDUCE = 4
    OPAQUE = 5

    def __init__(self):
        pass

    class Prim:
        """Descriptor of one primitive: iteration type, cost calibration
        and a relation function mapping (op, input_idx) to
        (axis_relation, elem_relation) between input and output axes."""

        def __init__(self, iter_type, calibrate=1, relation_func=None):
            # iter_type: one of PrimLib.UNKNOWN..OPAQUE.
            # calibrate: multiplier applied in calibrate_iter_size.
            self.iter_type = iter_type
            self.calibrate = calibrate
            self.relation_func = relation_func
            if relation_func is None:
                # default_relation_func stores plain (unbound) functions taken
                # from the class body, so `self` must be passed explicitly.
                self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x)

        def default_reshape_relation(self, op, input_idx):
            """Process reshape relation: like unknown, but every output axis
            is tagged RESHAPE instead of UNKNOWN."""
            axis_relation, elem_relation = self.unknown_relation(op, input_idx)
            elem_relation = [PrimLib.RESHAPE] * len(elem_relation)
            return axis_relation, elem_relation

        def default_elemwise_broadcast_relation(self, op, input_idx):
            """Process elemwise and broadcast relation.

            Aligns the input shape to the output shape from the right;
            missing leading axes map to None, equal-sized axes to ELEMWISE
            and size-mismatched axes to BROADCAST.
            """
            out_shape = op.output.shape
            in_shape = op.inputs[input_idx].shape
            if len(out_shape) < len(in_shape):
                raise ValueError("input/output size is abnormal")
            axis_relation, elem_relation = [], []
            delta = len(out_shape) - len(in_shape)
            if delta > 0:
                # Leading output axes with no corresponding input axis.
                for i in range(0, delta):
                    axis_relation.append(None)
                    elem_relation.append(None)
            for i, _ in enumerate(in_shape):
                axis_relation.append(i)
                elem_relation.append(
                    PrimLib.ELEMWISE if out_shape[i + delta] == in_shape[i] else PrimLib.BROADCAST)
            return axis_relation, elem_relation

        def default_reduce_relation(self, op, input_idx):
            """Process reduce relation: elemwise/broadcast relation with the
            reduced axes (op.attrs['reduce_axis']) re-tagged as REDUCE."""
            axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx)
            for i in op.attrs['reduce_axis']:
                elem_relation[i] = PrimLib.REDUCE
            return axis_relation, elem_relation

        def unknown_relation(self, op, input_idx):
            """Process unknown relation: every output axis may depend on
            every input axis, tagged UNKNOWN."""
            out_shape = op.output.shape
            in_shape = op.inputs[input_idx].shape
            all_relation = list(range(len(in_shape)))
            axis_relation = [all_relation for i in range(0, len(out_shape))]
            elem_relation = [PrimLib.UNKNOWN for i in range(0, len(out_shape))]
            return axis_relation, elem_relation

        # Indexed by iter_type: UNKNOWN, RESHAPE, ELEMWISE, BROADCAST,
        # REDUCE, OPAQUE (OPAQUE falls back to unknown_relation).
        default_relation_func = [
            unknown_relation,
            default_reshape_relation,
            default_elemwise_broadcast_relation,
            default_elemwise_broadcast_relation,
            default_reduce_relation,
            unknown_relation,
        ]

    # Registry of known primitives. NOTE: the spelling `primtives` (sic) is
    # part of the public interface and kept for compatibility.
    primtives = {
        'Add': Prim(ELEMWISE),
        'Abs': Prim(ELEMWISE),
        'Neg': Prim(ELEMWISE),
        'Mul': Prim(ELEMWISE),
        'Sub': Prim(ELEMWISE),
        'Log': Prim(ELEMWISE),
        'IsNan': Prim(ELEMWISE),
        'IsInf': Prim(ELEMWISE),
        'IsFinite': Prim(ELEMWISE),
        'Exp': Prim(ELEMWISE),
        'Rsqrt': Prim(ELEMWISE),
        'Sqrt': Prim(ELEMWISE),
        'Div': Prim(ELEMWISE),
        'FloorDiv': Prim(ELEMWISE),
        'RealDiv': Prim(ELEMWISE),
        'Mod': Prim(ELEMWISE),
        'Floor': Prim(ELEMWISE),
        'FloorMod': Prim(ELEMWISE),
        'Erf': Prim(ELEMWISE),
        'Erfc': Prim(ELEMWISE),
        'Cast': Prim(ELEMWISE),
        'Pow': Prim(ELEMWISE),
        'Minimum': Prim(ELEMWISE),
        'Maximum': Prim(ELEMWISE),
        'Reciprocal': Prim(ELEMWISE),
        'Equal': Prim(ELEMWISE),
        'NotEqual': Prim(ELEMWISE),
        'Greater': Prim(ELEMWISE),
        'GreaterEqual': Prim(ELEMWISE),
        'Less': Prim(ELEMWISE),
        'LessEqual': Prim(ELEMWISE),
        'LogicalNot': Prim(ELEMWISE),
        'LogicalAnd': Prim(ELEMWISE),
        'LogicalOr': Prim(ELEMWISE),
        'Square': Prim(ELEMWISE),
        'AddN': Prim(ELEMWISE),
        'Select': Prim(ELEMWISE, 8),  # calibrate=8: Select costs more per element
        'ReduceSum': Prim(REDUCE),
        'ReduceMax': Prim(REDUCE),
        'ReduceMin': Prim(REDUCE),
        'Argmax': Prim(REDUCE),
        'Argmin': Prim(REDUCE),
        'Assign': Prim(ELEMWISE),
        'Sign': Prim(ELEMWISE),
        'Sin': Prim(ELEMWISE),
        'Cos': Prim(ELEMWISE),
        'Asin': Prim(ELEMWISE),
        'ACos': Prim(ELEMWISE),
        'Tanh': Prim(ELEMWISE),
        'Asinh': Prim(ELEMWISE),
        'Acosh': Prim(ELEMWISE),
        'InplaceAssign': Prim(ELEMWISE),
        '@ReduceInit': Prim(ELEMWISE),
        'Reshape': Prim(RESHAPE),
        'Squeeze': Prim(RESHAPE),
        'Flatten': Prim(RESHAPE),
        'FlattenGrad': Prim(RESHAPE),
        'Transpose': Prim(OPAQUE),
        'Tile': Prim(BROADCAST),
        'BroadcastTo': Prim(BROADCAST),
        'StridedSlice': Prim(OPAQUE),
        'MatMul': Prim(OPAQUE),
        'TransData': Prim(OPAQUE),
        'BatchMatMul': Prim(OPAQUE),
        'UnPadAkg': Prim(OPAQUE),
        'PadAkg': Prim(OPAQUE),
        'Conv2D': Prim(OPAQUE),
        'CReal': Prim(ELEMWISE),
        'CImag': Prim(ELEMWISE),
        'Complex': Prim(ELEMWISE),
        'Atan': Prim(ELEMWISE),
        'Atan2': Prim(ELEMWISE),
        'Expm1': Prim(ELEMWISE),
        'TensorScatterAdd': Prim(OPAQUE),
        'Gather': Prim(OPAQUE),
        'GatherNd': Prim(OPAQUE),
        'UnsortedSegmentSum': Prim(OPAQUE),
        'StandardNormal': Prim(OPAQUE),
        'UserDefined': Prim(OPAQUE),
    }

    # Fallback descriptor for unregistered primitives.
    default_primtive = Prim(UNKNOWN)

    @classmethod
    def get_prim(cls, op):
        """Get op primtive; warn and fall back to default for unknown ops."""
        prim = cls.primtives.get(op.prim, None)
        if prim is None:
            print('[WARN] primtive is not registered: ' + op.prim)
            prim = cls.default_primtive
        return prim

    @classmethod
    def input_relation(cls, op, input_idx):
        """Get op's input_relation according to input_idx.

        Returns:
            tuple: (axis_relation, elem_relation) for that input.
        """
        return cls.get_prim(op).relation_func(op, input_idx)

    @classmethod
    def iter_type(cls, op):
        """Get op's iter type (one of UNKNOWN..OPAQUE)."""
        return cls.get_prim(op).iter_type

    @classmethod
    def is_reduce(cls, op):
        """Check whether op's iter type is reduce."""
        return cls.get_prim(op).iter_type == cls.REDUCE

    @classmethod
    def calibrate_iter_size(cls, op, iter_size):
        """Get calibrate_iter_size: iter_size scaled by the prim's calibrate factor."""
        return cls.get_prim(op).calibrate * iter_size

    @classmethod
    def dtype_bytes(cls, dtype):
        """Get dtype bytes by parsing the trailing decimal bit-width of the
        dtype string (e.g. 'float32' -> 4, 'int8' -> 1).

        NOTE(review): dtype names with no numeric suffix ('bool', 'float',
        'int', 'uint') yield 0 here — confirm callers expect that.
        """
        bits, unit = 1, 1
        # Walk digits from the end, accumulating them with increasing place
        # value; stop at the first non-digit. bits starts at 1, which does not
        # change the result for widths that are multiples of 8.
        for i in range(len(dtype) - 1, 0, -1):
            if dtype[i].isdecimal():
                bits += int(dtype[i]) * unit
                unit *= 10
            else:
                break
        return bits // 8

    @classmethod
    def inplace_reuse(cls, op, input_idx, start_axis=0):
        """Check whether op's output may reuse the input buffer in place:
        output must not be wider than the input and every axis from
        start_axis on must relate elementwise."""
        if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype):
            return False
        _, elem_relation = cls.get_prim(op).relation_func(op, input_idx)
        for i in range(start_axis, len(elem_relation)):
            if elem_relation[i] != cls.ELEMWISE:
                return False
        return True
294
295
class Tensor:
    """A tensor node in the cost-model graph, linking producer and consumers."""

    # Parameter roles of a tensor within a graph.
    PARA_NONE = 0
    PARA_INPUT = 1
    PARA_OUTPUT = 2

    class Buddy:
        """A group of tensors bound together; starts with a single leader."""

        def __init__(self, leader):
            self.members = [leader]

    def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0):
        self.name = name
        self.shape = shape
        self.dtype = dtype
        self.data_format = data_format
        self.para_type = para_type
        self.op = None      # producer Operator (set by Operator.__init__)
        self.to_ops = []    # consumer Operators
        self.buddy = None   # shared Buddy group, created lazily

    def __str__(self):
        return "%s%s" % (self.name, list(self.shape))

    def __repr__(self):
        return "{}.{}{}".format(self.name, self.dtype, list(self.shape))

    def get_size(self):
        """Return the tensor size in bytes (dtype bytes times element count)."""
        total = PrimLib.dtype_bytes(self.dtype)
        for dim in self.shape:
            total *= dim
        return total

    def add_buddy(self, tensor):
        """Add tensor to this tensor's buddy group, creating the group on demand."""
        if self.buddy is None:
            self.buddy = self.Buddy(self)
        self.buddy.members.append(tensor)
        tensor.buddy = self.buddy
338
339
class Value:
    """A scalar constant operand; always shaped [1]."""

    def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT):
        self.name = name
        self.shape = [1]  # values are scalar-shaped by construction
        self.dtype = dtype
        self.value = value
        self.data_format = data_format

    def __str__(self):
        return "%s%s" % (self.name, list(self.shape))

    def __repr__(self):
        return "{}.{}{}".format(self.name, self.dtype, list(self.shape))

    def get_size(self):
        """Return the nominal size of a value operand (always 1)."""
        return 1
359
360
class Operator:
    """One graph operator: a primitive applied to input tensors, producing one output."""

    def __init__(self, primtive, inputs, output, attrs):
        self.prim = primtive
        self.inputs = inputs
        self.output = output
        self.attrs = attrs
        # Register this operator as a consumer of every input tensor.
        for tensor in inputs:
            tensor.to_ops.append(self)
        # Claim the output tensor only if it has no producer yet.
        if output.op is None:
            output.op = self
        self.all_inputs = []  # include Tensor inputs and Value inputs.

    def __str__(self):
        args = ', '.join(str(t) for t in self.all_inputs)
        expr = "%s = %s.%s(%s) id:%s" % (
            str(self.output), self.prim, self.output.dtype, args, id(self))
        if self.attrs:
            return '%s // %s' % (expr, str(self.attrs))
        return expr

    def __repr__(self):
        return self.__str__()
383
384
class Graph:
    """A fusion (sub)graph: a named, topologically ordered list of Operators."""

    def __init__(self, name, ops, stitch_info=None, recompute_ops=None):
        self.name = name
        self.ops = ops  # in topo order, can not use set
        self.inputs = []   # explicit inputs; empty means deduce_parameters decides
        self.outputs = []  # explicit outputs; empty means deduce_parameters decides
        self.stitch_info = stitch_info      # buffer-stitch info object or None
        self.recompute_ops = recompute_ops  # ops to be recomputed, or None
        self.processor = ""  # target processor string, set via set_processor

    def set_processor(self, processor):
        """Set the target processor (e.g. backend identifier) for this graph."""
        self.processor = processor

    def add(self, ops):
        """Add a single Operator or an iterable of Operators to the graph."""
        if isinstance(ops, Operator):
            self.ops.append(ops)
        else:
            self.ops.extend(ops)

    def extract_subgraph(self, graph_name, tensor_names, difference=False):
        """Extract a subgraph from this graph.

        Args:
            graph_name (str): name of the new graph.
            tensor_names: output tensor names selecting the ops.
            difference (bool): if True, take ops whose output is NOT in
                tensor_names; otherwise take exactly the named ops.

        Raises:
            ValueError: if difference is False and some name matches no op.
        """
        graph = Graph(graph_name, [])
        outputs = set(tensor_names)
        if difference:
            for op in self.ops:
                if op.output.name not in outputs:
                    graph.add(op)
        else:
            for op in self.ops:
                if op.output.name in outputs:
                    graph.add(op)
                    outputs.remove(op.output.name)
            # Anything left over was never produced inside this graph.
            for name in outputs:
                raise ValueError("invalid input tensor : " + name)
        return graph

    def deduce_parameters(self):
        """Deduce the graph's parameters.

        Returns:
            tuple: (inputs, outputs); inputs are tensors produced outside
            the graph, outputs are tensors marked PARA_OUTPUT or consumed
            (also) outside the graph. Explicit self.inputs/self.outputs
            override the deduced lists.
        """
        inputs, outputs = [], []
        for op in self.ops:
            for t in op.inputs:
                # A tensor produced by an op outside self.ops is a parameter.
                if t not in inputs and t.op not in self.ops:
                    inputs.append(t)
            if op.output in outputs:
                continue
            if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
                outputs.append(op.output)
                continue
            # Output escapes the graph if any consumer lies outside it.
            if any([succ not in self.ops for succ in op.output.to_ops]):
                outputs.append(op.output)
        if self.inputs:
            inputs = self.inputs

        if self.outputs:
            outputs = self.outputs
        return inputs, outputs

    def __str__(self):
        inputs, outputs = self.deduce_parameters()
        para_str = ', '.join([repr(t) for t in inputs])
        out_str = ', '.join([repr(t) for t in outputs])
        lines = []
        lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
        if self.stitch_info:
            if self.stitch_info.stitch_ops:
                lines.append('  stitch -> ' + str(self.stitch_info.stitch_ops))
            if self.stitch_info.stitch_atomic_ops:
                lines.append('  stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops))

        for op in self.ops:
            lines.append('  ' + str(op))
        lines.append('}')
        return '\n'.join(lines)

    def __repr__(self):
        return str(self)

    def dump(self):
        """Dump Graph to a json-serializable dict (AKG composite-op format)."""
        # Graph attr names are renamed to their json counterparts here.
        attr_name = {'reduce_axis': 'axis'}
        inputs, outputs = self.deduce_parameters()
        input_desc, output_desc, op_desc = [], [], []
        for t in inputs:
            # NOTE: each input desc is wrapped in a single-element list,
            # matching the expected composite-graph json schema.
            input_desc.append([{'data_type': t.dtype, 'shape': t.shape,
                                'tensor_name': t.name, 'format': t.data_format}])
        for t in outputs:
            output_desc.append({'data_type': t.dtype, 'shape': t.shape,
                                'tensor_name': t.name, 'format': t.data_format})
        for op in self.ops:
            attrs, in_desc = [], []
            for a in op.attrs:
                name = attr_name.get(a, a)
                attrs.append(
                    {'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])})
            for t in op.all_inputs:
                if isinstance(t, Tensor):
                    in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape,
                                     'tensor_name': t.name, 'format': t.data_format}])
                else:
                    # Value inputs additionally carry their constant value.
                    in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
                                     'tensor_name': t.name, 'format': t.data_format}])
            out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
                         'tensor_name': op.output.name, 'format': op.output.data_format}]
            op_desc.append({'attr': attrs, 'impl_path': '',
                            'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc})

        graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
                      'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
                      'platform': 'AKG', 'process': self.processor}

        if self.stitch_info and self.stitch_info.stitch_ops:
            buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)}
            if self.stitch_info.stitch_atomic_ops:
                buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops)
            graph_desc['buffer_stitch'] = buffer_stitch

        return graph_desc
506
507
class GraphVisitor:
    """Base class for graph traversals; subclasses provide visit(op)."""

    def __init__(self, forward=True):
        # forward=True walks ops in topological order, False in reverse.
        self.forward = forward

    def visit_graph(self, graph):
        """Call self.visit on every op, in the configured direction."""
        ops = graph.ops if self.forward else list(reversed(graph.ops))
        for op in ops:
            self.visit(op)
522
523
524class AlignShape(GraphVisitor):
525    """Align shape"""
526
527    def __init__(self):
528        super(AlignShape, self).__init__()
529
530    def visit(self, op):
531        """Visit op node"""
532        prim = PrimLib.get_prim(op)
533        if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
534            out_dim = len(op.output.shape)
535            align_dim = out_dim
536            for t in op.inputs:
537                if len(t.shape) > align_dim:
538                    align_dim = len(t.shape)
539            if align_dim > out_dim:
540                op.output.shape = [1] * (align_dim - out_dim) + op.output.shape
541
542
543class AddControlBuddy(GraphVisitor):
544    """Add control buddy"""
545
546    def __init__(self):
547        super(AddControlBuddy, self).__init__()
548        self.buddies = {}  # {op : [ctrl_op]}
549
550    def visit(self, op):
551        """Visit op node"""
552        if op.prim == "MakeTuple":
553            if len(op.output.to_ops) != 1:
554                raise ValueError("operator's output size is abnormal")
555            owner = op.output.to_ops[0]
556            if owner in self.buddies:
557                self.buddies[owner].append(op)
558            else:
559                self.buddies[owner] = [op]
560            if op in self.buddies:
561                ops = self.buddies.pop(op)
562                self.buddies[owner].extend(ops)
563
564    def visit_graph(self, graph):
565        """Visit graph nodes"""
566        super(AddControlBuddy, self).visit_graph(graph)
567        for owner in self.buddies:
568            for op in self.buddies[owner]:
569                owner.add_buddy(op.output)
570
571
572class GraphKernelUnsupportedException(Exception):
573    """"GraphKernel Unsupported Exception"""
574
575    def __init__(self, message):
576        super(GraphKernelUnsupportedException, self).__init__()
577        self.message = message
578