# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Cost model for parallel fusion"""
from .model import PrimLib


class ParalGain:
    """Gain information of parallel fusion."""

    def __init__(self, fusion_type, bottleneck, gain, block_assign, type_info):
        self.fusion_type = fusion_type
        self.bottleneck = bottleneck
        self.gain = gain
        self.block_assign = block_assign
        self.type_info = type_info


class ScheduleAnalyzer:
    """Schedule analyzer."""
    WARP_SIZE = 32
    MAX_SM = 80  # Volta
    MAX_NUM_THREADS = 1024
    MAX_BLOCK = 256
    PIPELINE_OP_THRESHOLD = 5

    def __init__(self, graph):
        self.graph = graph
        self.block_num = 0
        self.block_weight = 0
        _, outputs = graph.deduce_parameters()
        self.ops = graph.ops
        self.dom_op = [out.op for out in outputs]

    @staticmethod
    def prod(shape):
        """Compute the product of all dimensions in shape."""
        res = shape[0]
        for i in range(1, len(shape)):
            res = res * shape[i]
        return res

    def _cal_weight(self, ops):
        """Calculate the total bytes written by the outputs of ops."""
        weight = 0
        for op in ops:
            weight += (self.prod(op.output.shape)
                       * PrimLib.dtype_bytes(op.output.dtype))
        return weight
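
    # Worked example (illustrative, not from the original source): one op
    # writing a float32 output of shape [1024, 1024] contributes
    # prod([1024, 1024]) * 4 bytes = 4 MiB, i.e. the weight is the total
    # number of output bytes produced by the given ops.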

    def injective_analyze(self):
        """Analyze the injective (elementwise/broadcast) case."""
        const_size = max(self.prod(op.output.shape) for op in self.dom_op)
        # Round the largest output size up to a multiple of the thread count.
        const_size = ((const_size + self.MAX_NUM_THREADS - 1)
                      // self.MAX_NUM_THREADS * self.MAX_NUM_THREADS)

        total_weight = self._cal_weight(self.ops)
        total_block = (const_size + self.MAX_NUM_THREADS - 1) // self.MAX_NUM_THREADS
        need_block_split = const_size > self.MAX_BLOCK * self.MAX_NUM_THREADS
        if need_block_split:
            self.block_num = self.MAX_BLOCK
            waves = (total_block + self.MAX_BLOCK - 1) // self.MAX_BLOCK
            self.block_weight = total_weight // total_block * waves
        else:
            self.block_num = total_block
            self.block_weight = total_weight // self.block_num
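
    # Worked example (illustrative): if the largest dominant output has
    # 2**21 = 2097152 elements, const_size is already a multiple of 1024 and
    # total_block = 2048 > MAX_BLOCK, so the blocks are split:
    # block_num = 256, waves = 2048 // 256 = 8, and
    # block_weight = total_weight // 2048 * 8, i.e. 1/256 of the total bytes.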

    def reduce_analyze(self):
        """Analyze the reduce case."""
        thread_x, thread_y = 32, 32
        reduce_op = None
        for op in self.ops:
            if PrimLib.iter_type(op) == PrimLib.REDUCE:
                if reduce_op:
                    raise RuntimeError("Multiple reduce ops in one graph are not supported yet.")
                reduce_op = op
        if not reduce_op:
            raise RuntimeError("reduce_analyze is called, but no reduce op was found.")
        shape = reduce_op.inputs[0].shape
        reduce_axis = reduce_op.attrs['reduce_axis']
        total_space = self.prod(shape)
        red_space = shape[reduce_axis[0]]
        for i in range(1, len(reduce_axis)):
            red_space *= shape[reduce_axis[i]]
        dtype_size = PrimLib.dtype_bytes(reduce_op.output.dtype)

        weight = self._cal_weight(self.ops)  # reduce + injective
        block_x = (total_space // red_space + thread_y - 1) // thread_y
        block_w = (weight + block_x - 1) // block_x
        waves = (block_x + self.MAX_BLOCK - 1) // self.MAX_BLOCK
        self.block_num = min(self.MAX_BLOCK, block_x)
        all_reduce = 10  # 1 reduce init + 3 sync + 5 bin + 1 write
        self.block_weight = (block_w + all_reduce *
                             dtype_size * thread_x * thread_y) * waves
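
    # Worked example (illustrative): reducing a float32 [4096, 1024] input
    # over axis 1 gives total_space = 4194304 and red_space = 1024, so
    # block_x = ceil(4096 / 32) = 128 blocks in a single wave, and each block
    # is charged block_w bytes plus a fixed all-reduce overhead of
    # 10 * 4 * 32 * 32 = 40960 bytes.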

    def default_analyze(self):
        """Analyze the default case."""
        def _cal_default_space(op):
            space = self.prod(op.output.shape)
            for t in op.inputs:
                size = self.prod(t.shape)
                if size > space:
                    space = size
            return space
        space = max(_cal_default_space(op) for op in self.dom_op)

        # Give each SM at least 4 warps of work.
        block = (space + (self.WARP_SIZE * 4) - 1) // (self.WARP_SIZE * 4)
        self.block_num = min(self.MAX_BLOCK, block)
        self.block_weight = self._cal_weight(self.ops) // self.block_num
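
    # Worked example (illustrative): with a dominant space of 65536 elements
    # and 4 warps (128 threads) per block, block = ceil(65536 / 128) = 512,
    # which is capped at MAX_BLOCK, so block_num = 256.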

    def analyze(self):
        """Analyze ops and dispatch to the matching case."""
        def _ops_type(ops, dom_op):
            # Any reduce op dominates the schedule choice.
            have_reduce = any(PrimLib.iter_type(op) == PrimLib.REDUCE for op in ops)
            if have_reduce:
                return PrimLib.REDUCE
            return PrimLib.iter_type(dom_op[0])

        dom_type = _ops_type(self.ops, self.dom_op)
        if dom_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
            self.injective_analyze()
        elif dom_type == PrimLib.REDUCE:
            self.reduce_analyze()
        else:
            self.default_analyze()

    def suitable_to_pipeline(self):
        """Judge whether the graph is suitable for pipeline optimization."""
        # Reduce is not suitable: it may make the tiling bad.
        def _contain_reduce(ops):
            for op in ops:
                if PrimLib.iter_type(op) == PrimLib.REDUCE:
                    return True
            return False

        return not _contain_reduce(self.ops)

    @staticmethod
    def k_mean(data, class_n=2, exclude_id=()):
        """
        Find k clusters in which elements are close to each other.

        Args:
            data (list): Elements' information.
            class_n (int): Number of clusters to partition the data into. Default: 2.
            exclude_id (tuple[int]): Indices of elements to exclude. Default: ().

        Returns:
            classes (list[list[int]]): The list of clusters. Each cluster is a list of indices.
        """
        def _cal_mean(classes):
            class_datas = [[data[cid] for cid in cls] for cls in classes]
            return [sum(cls) / len(cls) if cls else float('inf') for cls in class_datas]

        def _cal_distance(a, b):
            return abs(a - b)

        def _check_different(old_classes, new_classes):
            for o, n in zip(old_classes, new_classes):
                if o != n:
                    return True
            return False

        if len(data) < class_n:
            return None
        # Seed each cluster with one non-excluded element.
        classes = []
        for i, _ in enumerate(data):
            if i in exclude_id:
                continue
            if len(classes) >= class_n:
                break
            classes.append([i])
        # Standard 1-D k-means: reassign elements to the nearest mean until stable.
        changed = True
        while changed:
            new_classes = [[] for _ in classes]
            means = _cal_mean(classes)
            for idx, d in enumerate(data):
                if idx in exclude_id:
                    continue
                min_idx = -1
                min_dis = float('inf')
                for i, m in enumerate(means):
                    cur_dis = _cal_distance(m, d)
                    min_idx = i if min_dis > cur_dis else min_idx
                    min_dis = cur_dis if min_dis > cur_dis else min_dis
                new_classes[min_idx].append(idx)
            changed = _check_different(classes, new_classes)
            classes = new_classes
        return classes
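
    # Minimal usage sketch (values chosen only for illustration):
    #   ScheduleAnalyzer.k_mean([1.0, 1.2, 9.8, 10.1], class_n=2)
    # converges to [[0, 1], [2, 3]]: one light cluster and one heavy cluster.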

    @staticmethod
    def pipeline_fusion_analyze(blocks, op_sizes, exclude_id):
        """Analyze whether the segments can be pipeline optimized."""
        # op size first, block second.
        def _simple_factor(block, op_size):
            return block + 5 * op_size

        def _take_second(elem):
            return elem[1]

        simple_indicators = [_simple_factor(b, s)
                             for b, s in zip(blocks, op_sizes)]
        # Two clusters: one heavy, the other light.
        classes = ScheduleAnalyzer.k_mean(simple_indicators, 2, exclude_id)
        if not classes:
            return []
        means = [sum([simple_indicators[idx] for idx in cls]) /
                 len(cls) if cls else float('inf') for cls in classes]

        # The target two clusters should be a heavy one and a light one.
        # The light one may be suitable for pipeline optimization.
        classes_infos = [[cls, m] for cls, m in zip(classes, means)]
        classes_infos.sort(key=_take_second)
        # After sorting, the first non-empty cluster is the lightest one.
        pipeline_target = None
        for ci in classes_infos:
            if ci[0]:
                pipeline_target = ci
                break
        if not pipeline_target:
            return []
        pipeline_gids, pipeline_mean = pipeline_target
        if pipeline_mean > _simple_factor(float(ScheduleAnalyzer.MAX_SM) / len(blocks),
                                          ScheduleAnalyzer.PIPELINE_OP_THRESHOLD):
            return []

        pipeline_blocks = []
        pipeline_weight = len(pipeline_gids)
        # Try to make at least two parallel groups.
        if pipeline_weight > 3 and pipeline_weight > len(blocks) / 2:
            if len(pipeline_gids[:pipeline_weight // 2]) > 1:
                pipeline_blocks.append(pipeline_gids[:pipeline_weight // 2])
            if len(pipeline_gids[pipeline_weight // 2:]) > 1:
                pipeline_blocks.append(pipeline_gids[pipeline_weight // 2:])
        elif pipeline_weight > 1:
            pipeline_blocks.append(pipeline_gids)
        return pipeline_blocks
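
    # Worked example (illustrative): with 8 candidate segments the threshold
    # is _simple_factor(80 / 8, 5) = 10 + 5 * 5 = 35, so the light cluster is
    # pipelined only if its mean indicator does not exceed 35.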

    @staticmethod
    def fusion_consult(blocks, op_sizes, exclude_gid):
        """Get a recommendation for parallel fusion."""
        # Default is block fusion.
        fusion_type = "block_fusion"
        type_info = None

        activate_pipeline_optimization = False  # Disable pipeline optimization for now.
        if activate_pipeline_optimization:
            pipeline_info = ScheduleAnalyzer.pipeline_fusion_analyze(
                blocks, op_sizes, exclude_gid)
            if pipeline_info:
                fusion_type = "block_pipeline_fusion"
                type_info = pipeline_info

        return fusion_type, type_info


def block_parallel_estimate(graphs):
    """Estimate the gain of block parallel fusion."""
    sum_block, max_weight, sum_weight, blocks, op_sizes, exclude_gid = 0, 0, 0, [], [], []
    for gid, g in enumerate(graphs):
        s = ScheduleAnalyzer(g)
        s.analyze()
        sum_block += s.block_num
        if s.block_weight > max_weight:
            max_weight = s.block_weight
        sum_weight += s.block_weight
        blocks.append(s.block_num)
        op_sizes.append(len(s.ops))
        if not s.suitable_to_pipeline():
            exclude_gid.append(gid)
    # Too many blocks to overlap on the device: no parallel gain expected.
    if sum_block > ScheduleAnalyzer.MAX_SM * 32:
        return ParalGain("none", sum_weight, 0, [0 for _ in graphs], None)

    fusion_type, type_info = ScheduleAnalyzer.fusion_consult(blocks, op_sizes, tuple(exclude_gid))
    return ParalGain(fusion_type, max_weight, sum_weight - max_weight, blocks, type_info)
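
# Worked example (illustrative): if two graphs have block weights of 6 MiB
# and 2 MiB, then max_weight = 6 MiB and the estimated gain is
# sum_weight - max_weight = 2 MiB: the bytes that could be hidden by running
# the lighter graph in parallel with the heaviest one.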


def parallel_estimate(graphs):
    """Estimate parallel gain."""
    return block_parallel_estimate(graphs)