# Copyright 2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""Cost model for parallel fusion"""
from .model import PrimLib


class ParalGain:
    """Gain information for parallel fusion."""

    def __init__(self, fusion_type, bottleneck, gain, block_assign, type_info):
        self.fusion_type = fusion_type
        self.bottleneck = bottleneck
        self.gain = gain
        self.block_assign = block_assign
        self.type_info = type_info


class ScheduleAnalyzer:
    """schedule analyzer"""
    WARP_SIZE = 32
    MAX_SM = 80  # Volta
    MAX_NUM_THREADS = 1024
    MAX_BLOCK = 256
    PIPELINE_OP_THRESHOLD = 5

    def __init__(self, graph):
        self.graph = graph
        self.block_num = 0
        self.block_weight = 0
        _, outputs = graph.deduce_parameters()
        self.ops = graph.ops
        self.dom_op = [out.op for out in outputs]

    @staticmethod
    def prod(shape):
        """Compute the product of a non-empty shape"""
        res = shape[0]
        for i in range(1, len(shape)):
            res = res * shape[i]
        return res

    def _cal_weight(self, ops):
        """Sum the output sizes of the given ops, in bytes"""
        weight = 0
        for op in ops:
            weight += self.prod(op.output.shape) * \
                PrimLib.dtype_bytes(op.output.dtype)
        return weight

    def injective_analyze(self):
        """analyze injective case"""
        const_size = max([self.prod(op.output.shape) for op in self.dom_op])
        const_size = (const_size + self.MAX_NUM_THREADS -
                      1) // self.MAX_NUM_THREADS * self.MAX_NUM_THREADS

        total_weight = self._cal_weight(self.ops)
        total_block = (const_size + self.MAX_NUM_THREADS -
                       1) // self.MAX_NUM_THREADS
        need_block_split = const_size > self.MAX_BLOCK * self.MAX_NUM_THREADS
        if need_block_split:
            self.block_num = self.MAX_BLOCK
            waves = (total_block + self.MAX_BLOCK - 1) // self.MAX_BLOCK
            self.block_weight = total_weight // total_block * waves
        else:
            self.block_num = total_block
            self.block_weight = total_weight // self.block_num

    def reduce_analyze(self):
        """analyze reduce case"""
        thread_x, thread_y = 32, 32
        reduce_op = None
        for op in self.ops:
            if PrimLib.iter_type(op) == PrimLib.REDUCE:
                if reduce_op:
                    raise RuntimeError(
                        "Multiple reduce ops in one graph are not supported yet.")
                reduce_op = op
        if not reduce_op:
            raise RuntimeError("No reduce op found during reduce analysis!")
        shape = reduce_op.inputs[0].shape
        reduce_axis = reduce_op.attrs['reduce_axis']
        total_space = self.prod(shape)
        red_space = shape[reduce_axis[0]]
        for i in range(1, len(reduce_axis)):
            red_space *= shape[reduce_axis[i]]
        dtype_size = PrimLib.dtype_bytes(reduce_op.output.dtype)

        weight = self._cal_weight(self.ops)  # reduce + injective
        block_x = (total_space // red_space + thread_y - 1) // thread_y
        block_w = (weight + block_x - 1) // block_x
        waves = (block_x + self.MAX_BLOCK - 1) // self.MAX_BLOCK
        self.block_num = min(self.MAX_BLOCK, block_x)
        all_reduce = 10  # 1 reduce init + 3 sync + 5 bin + 1 write
        self.block_weight = (block_w + all_reduce *
                             dtype_size * thread_x * thread_y) * waves
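
    # Worked example for reduce_analyze (illustrative, with assumed shapes):
    # for a single float32 reduce over axis 1 of a (1024, 256) input,
    #   total_space = 1024 * 256 = 262144, red_space = 256,
    #   block_x = (262144 // 256 + 31) // 32 = 32,
    #   weight = 1024 * 4 = 4096 (bytes of the (1024,) output),
    #   block_w = (4096 + 31) // 32 = 128, waves = 1,
    # so block_num = 32 and block_weight = (128 + 10 * 4 * 32 * 32) * 1 = 41088.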

    def default_analyze(self):
        """analyze default case"""
        def _cal_default_space(op):
            space = self.prod(op.output.shape)
            for t in op.inputs:
                size = self.prod(t.shape)
                if size > space:
                    space = size
            return space
        space = max([_cal_default_space(op) for op in self.dom_op])

        # Give each SM at least 4 warps.
        block = (space + (self.WARP_SIZE * 4) - 1) // (self.WARP_SIZE * 4)
        self.block_num = min(self.MAX_BLOCK, block)
        self.block_weight = self._cal_weight(self.ops) // self.block_num

    def analyze(self):
        """analyze ops"""
        def _ops_type(ops, dom_op):
            have_reduce = any(
                PrimLib.iter_type(op) == PrimLib.REDUCE for op in ops)
            if have_reduce:
                return PrimLib.REDUCE
            return PrimLib.iter_type(dom_op[0])

        dom_type = _ops_type(self.ops, self.dom_op)
        if dom_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST):
            self.injective_analyze()
        elif dom_type == PrimLib.REDUCE:
            self.reduce_analyze()
        else:
            self.default_analyze()

    def suitable_to_pipeline(self):
        """judge whether the graph is suitable for pipeline optimization"""
        def _contain_reduce(ops):
            for op in ops:
                # Reduce may make the tiling bad.
                if PrimLib.iter_type(op) == PrimLib.REDUCE:
                    return True
            return False

        # A graph that contains reduce is not suitable.
        return not _contain_reduce(self.ops)

    @staticmethod
    def k_mean(data, class_n=2, exclude_id=()):
        """
        Find k clusters whose elements are close to each other.

        Args:
            data (list): Elements' information.
            class_n (int): Number of clusters to form, default is 2.
            exclude_id (tuple[int]): Indices of elements to exclude, default is ().

        Returns:
            classes (list[list[int]]): The list of clusters. Each cluster is a list of indices.
        """
        def _cal_mean(classes):
            class_datas = [[data[cid] for cid in cls] for cls in classes]
            return [sum(cls) / len(cls) if cls else float('inf') for cls in class_datas]

        def _cal_distance(a, b):
            return abs(a - b)

        def _check_different(old_classes, new_classes):
            for o, n in zip(old_classes, new_classes):
                if o != n:
                    return True
            return False

        if len(data) < class_n:
            return None
        classes = []
        for i, _ in enumerate(data):
            if i in exclude_id:
                continue
            if len(classes) >= class_n:
                break
            classes.append([i])
        changed = True
        while changed:
            new_classes = [[] for _ in classes]
            means = _cal_mean(classes)
            for idx, d in enumerate(data):
                if idx in exclude_id:
                    continue
                min_idx = -1
                min_dis = float('inf')
                for i, m in enumerate(means):
                    cur_dis = _cal_distance(m, d)
                    if cur_dis < min_dis:
                        min_idx, min_dis = i, cur_dis
                new_classes[min_idx].append(idx)
            changed = _check_different(classes, new_classes)
            classes = new_classes
        return classes
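
    # Illustrative k_mean call (values assumed): four per-graph indicators
    # split into a light and a heavy cluster of element indices:
    #   ScheduleAnalyzer.k_mean([1, 2, 30, 40], class_n=2)
    #   -> [[0, 1], [2, 3]]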

    @staticmethod
    def pipeline_fusion_analyze(blocks, op_sizes, exclude_id):
        """analyze whether the segments can be pipeline optimized"""
        # op size first, block second.
        def _simple_factor(block, op_size):
            return block + 5 * op_size

        def _take_second(elem):
            return elem[1]

        simple_indicators = [_simple_factor(b, s)
                             for b, s in zip(blocks, op_sizes)]
        # Two classes: one heavy, the other light.
        classes = ScheduleAnalyzer.k_mean(simple_indicators, 2, exclude_id)
        if not classes:
            return []
        means = [sum([simple_indicators[idx] for idx in cls]) /
                 len(cls) if cls else float('inf') for cls in classes]

        # The two target clusters should be a heavy one and a light one.
        # The light one may be suitable for pipeline optimization.
        classes_infos = [[cls, m] for cls, m in zip(classes, means)]
        classes_infos.sort(key=_take_second)
        pipeline_target = None
        for ci in classes_infos:
            if ci:
                pipeline_target = ci
                break
        pipeline_gids, pipeline_mean = pipeline_target
        if pipeline_mean > _simple_factor(float(ScheduleAnalyzer.MAX_SM) / len(blocks),
                                          ScheduleAnalyzer.PIPELINE_OP_THRESHOLD):
            return []

        pipeline_blocks = []
        pipeline_weight = len(pipeline_gids)
        # Try to produce at least two parallel pipelines.
        if pipeline_weight > 3 and pipeline_weight > len(blocks) / 2:
            if len(pipeline_gids[:pipeline_weight // 2]) > 1:
                pipeline_blocks.append(pipeline_gids[:pipeline_weight // 2])
            if len(pipeline_gids[pipeline_weight // 2:]) > 1:
                pipeline_blocks.append(pipeline_gids[pipeline_weight // 2:])
        elif pipeline_weight > 1:
            pipeline_blocks.append(pipeline_gids)
        return pipeline_blocks

    @staticmethod
    def fusion_consult(blocks, op_sizes, exclude_gid):
        """get a recommendation for parallel fusion"""
        # Default is block fusion.
        fusion_type = "block_fusion"
        type_info = None

        activate_pipeline_optimization = False  # Disable pipeline optimization for now.
        if activate_pipeline_optimization:
            pipeline_info = ScheduleAnalyzer.pipeline_fusion_analyze(
                blocks, op_sizes, exclude_gid)
            if pipeline_info:
                fusion_type = "block_pipeline_fusion"
                type_info = pipeline_info

        return fusion_type, type_info


def block_parallel_estimate(graphs):
    """estimate block parallel gain"""
    sum_block, max_weight, sum_weight, blocks, op_sizes, exclude_gid = 0, 0, 0, [], [], []
    for gid, g in enumerate(graphs):
        s = ScheduleAnalyzer(g)
        s.analyze()
        sum_block += s.block_num
        if s.block_weight > max_weight:
            max_weight = s.block_weight
        sum_weight += s.block_weight
        blocks.append(s.block_num)
        op_sizes.append(len(s.ops))
        if not s.suitable_to_pipeline():
            exclude_gid.append(gid)
    if sum_block > ScheduleAnalyzer.MAX_SM * 32:
        return ParalGain("none", sum_weight, 0, [0 for _ in graphs], None)

    fusion_type, type_info = ScheduleAnalyzer.fusion_consult(blocks, op_sizes, tuple(exclude_gid))
    return ParalGain(fusion_type, max_weight, sum_weight - max_weight, blocks, type_info)


def parallel_estimate(graphs):
    """Estimate parallel gain"""
    return block_parallel_estimate(graphs)
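
# Minimal usage sketch (illustrative; `graphs` is assumed to be a list of
# fusible model.Graph subgraphs produced by the graph kernel splitter):
#
#     gain = parallel_estimate(graphs)
#     if gain.fusion_type == "block_fusion" and gain.gain > 0:
#         # gain is sum(block weights) - max(block weight): the estimated
#         # saving from running the subgraphs in parallel; fuse when positive.
#         fuse_in_parallel(graphs, gain.block_assign)  # hypothetical caller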