# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ===========================================================================
"""GraphKernel cost model"""


class Utils:
    """Model utils"""

    def __init__(self):
        pass

    @staticmethod
    def get_attr_type(attr):
        """Return the serialized type name for an op attribute value.

        Raises:
            ValueError: if attr is an empty list/tuple or of an unknown type.
        """
        # bool must be tested before int: bool is a subclass of int in Python.
        if isinstance(attr, bool):
            return 'bool'
        if isinstance(attr, str):
            return 'str'
        if isinstance(attr, int):
            return 'int'
        if isinstance(attr, float):
            # Bugfix: this branch previously returned 'bool', mis-tagging
            # every float attribute in the dumped graph json.
            return 'float'
        if isinstance(attr, (list, tuple)):
            if not attr:
                raise ValueError("Length of attr is 0")
            if isinstance(attr[0], int):
                return 'listInt'
            if isinstance(attr[0], str):
                return 'listStr'
        raise ValueError("Unknown type of attr: {}".format(attr))


class DataFormat:
    """Tensor data-format name constants."""
    DEFAULT = "DefaultFormat"
    NC1KHKWHWC0 = "NC1KHKWHWC0"
    ND = "ND"
    NCHW = "NCHW"
    NHWC = "NHWC"
    HWCN = "HWCN"
    NC1HWC0 = "NC1HWC0"
    FRAC_Z = "FracZ"
    FRAC_NZ = "FRACTAL_NZ"
    C1HWNCOC0 = "C1HWNCoC0"
    NC1HWC0_C04 = "NC1HWC0_C04"
    FRACTAL_Z_C04 = "FRACTAL_Z_C04"
    NDHWC = "NDHWC"

    def __init__(self):
        pass


class DataType:
    """Tensor dtype name constants."""
    FLOAT = "float"
    FLOAT16 = "float16"
    FLOAT32 = "float32"
    FLOAT64 = "float64"
    INT = "int"
    INT8 = "int8"
    INT16 = "int16"
    INT32 = "int32"
    INT64 = "int64"
    UINT = "uint"
    UINT8 = "uint8"
    UINT16 = "uint16"
    UINT32 = "uint32"
    UINT64 = "uint64"
    BOOL = "bool"

    def __init__(self):
        pass


class PrimLib:
    """Registry of primitive op metadata (iteration type and axis relations)."""

    # Iteration-type codes, used as indices into Prim.default_relation_func.
    UNKNOWN = 0
    RESHAPE = 1
    ELEMWISE = 2
    BROADCAST = 3
    REDUCE = 4
    OPAQUE = 5

    def __init__(self):
        pass

    class Prim:
        """Metadata for one primitive: iter type, cost calibration factor,
        and a function mapping (op, input_idx) -> (axis_relation, elem_relation)."""

        def __init__(self, iter_type, calibrate=1, relation_func=None):
            self.iter_type = iter_type
            self.calibrate = calibrate
            self.relation_func = relation_func
            if relation_func is None:
                # Dispatch on iter_type; bound lazily so subclass-style reuse works.
                self.relation_func = lambda *x: self.default_relation_func[iter_type](self, *x)

        def default_reshape_relation(self, op, input_idx):
            """Process reshape relation: axes are unknown but elements map 1:1."""
            axis_relation, elem_relation = self.unknown_relation(op, input_idx)
            elem_relation = [PrimLib.RESHAPE] * len(elem_relation)
            return axis_relation, elem_relation

        def default_elemwise_broadcast_relation(self, op, input_idx):
            """Process elemwise and broadcast relation.

            The input shape is right-aligned against the output shape; leading
            output axes with no matching input axis get a None relation, and a
            size mismatch on an aligned axis marks it as BROADCAST.
            """
            out_shape = op.output.shape
            in_shape = op.inputs[input_idx].shape
            if len(out_shape) < len(in_shape):
                raise ValueError("input/output size is abnormal")
            axis_relation, elem_relation = [], []
            delta = len(out_shape) - len(in_shape)
            if delta > 0:
                for i in range(0, delta):
                    axis_relation.append(None)
                    elem_relation.append(None)
            for i, _ in enumerate(in_shape):
                axis_relation.append(i)
                elem_relation.append(
                    PrimLib.ELEMWISE if out_shape[i + delta] == in_shape[i] else PrimLib.BROADCAST)
            return axis_relation, elem_relation

        def default_reduce_relation(self, op, input_idx):
            """Process reduce relation: elemwise alignment plus REDUCE on reduce axes."""
            axis_relation, elem_relation = self.default_elemwise_broadcast_relation(op, input_idx)
            for i in op.attrs['reduce_axis']:
                elem_relation[i] = PrimLib.REDUCE
            return axis_relation, elem_relation

        def unknown_relation(self, op, input_idx):
            """Process unknown relation: every output axis may depend on every input axis."""
            out_shape = op.output.shape
            in_shape = op.inputs[input_idx].shape
            all_relation = list(range(len(in_shape)))
            axis_relation = [all_relation for i in range(0, len(out_shape))]
            elem_relation = [PrimLib.UNKNOWN for i in range(0, len(out_shape))]
            return axis_relation, elem_relation

        # Indexed by iter_type (UNKNOWN..OPAQUE).
        default_relation_func = [
            unknown_relation,
            default_reshape_relation,
            default_elemwise_broadcast_relation,
            default_elemwise_broadcast_relation,
            default_reduce_relation,
            unknown_relation,
        ]

    # NOTE: the spelling 'primtives' (and 'primtive' below) is part of the
    # existing interface; external callers reference it, so it is kept.
    primtives = {
        'Add': Prim(ELEMWISE),
        'Abs': Prim(ELEMWISE),
        'Neg': Prim(ELEMWISE),
        'Mul': Prim(ELEMWISE),
        'Sub': Prim(ELEMWISE),
        'Log': Prim(ELEMWISE),
        'IsNan': Prim(ELEMWISE),
        'IsInf': Prim(ELEMWISE),
        'IsFinite': Prim(ELEMWISE),
        'Exp': Prim(ELEMWISE),
        'Rsqrt': Prim(ELEMWISE),
        'Sqrt': Prim(ELEMWISE),
        'Div': Prim(ELEMWISE),
        'FloorDiv': Prim(ELEMWISE),
        'RealDiv': Prim(ELEMWISE),
        'Mod': Prim(ELEMWISE),
        'Floor': Prim(ELEMWISE),
        'FloorMod': Prim(ELEMWISE),
        'Erf': Prim(ELEMWISE),
        'Erfc': Prim(ELEMWISE),
        'Cast': Prim(ELEMWISE),
        'Pow': Prim(ELEMWISE),
        'Minimum': Prim(ELEMWISE),
        'Maximum': Prim(ELEMWISE),
        'Reciprocal': Prim(ELEMWISE),
        'Equal': Prim(ELEMWISE),
        'NotEqual': Prim(ELEMWISE),
        'Greater': Prim(ELEMWISE),
        'GreaterEqual': Prim(ELEMWISE),
        'Less': Prim(ELEMWISE),
        'LessEqual': Prim(ELEMWISE),
        'LogicalNot': Prim(ELEMWISE),
        'LogicalAnd': Prim(ELEMWISE),
        'LogicalOr': Prim(ELEMWISE),
        'Square': Prim(ELEMWISE),
        'AddN': Prim(ELEMWISE),
        'Select': Prim(ELEMWISE, 8),
        'ReduceSum': Prim(REDUCE),
        'ReduceMax': Prim(REDUCE),
        'ReduceMin': Prim(REDUCE),
        'Argmax': Prim(REDUCE),
        'Argmin': Prim(REDUCE),
        'Assign': Prim(ELEMWISE),
        'Sign': Prim(ELEMWISE),
        'Sin': Prim(ELEMWISE),
        'Cos': Prim(ELEMWISE),
        'Asin': Prim(ELEMWISE),
        'ACos': Prim(ELEMWISE),
        'Tanh': Prim(ELEMWISE),
        'Asinh': Prim(ELEMWISE),
        'Acosh': Prim(ELEMWISE),
        'InplaceAssign': Prim(ELEMWISE),
        '@ReduceInit': Prim(ELEMWISE),
        'Reshape': Prim(RESHAPE),
        'Squeeze': Prim(RESHAPE),
        'Flatten': Prim(RESHAPE),
        'FlattenGrad': Prim(RESHAPE),
        'Transpose': Prim(OPAQUE),
        'Tile': Prim(BROADCAST),
        'BroadcastTo': Prim(BROADCAST),
        'StridedSlice': Prim(OPAQUE),
        'MatMul': Prim(OPAQUE),
        'TransData': Prim(OPAQUE),
        'BatchMatMul': Prim(OPAQUE),
        'UnPadAkg': Prim(OPAQUE),
        'PadAkg': Prim(OPAQUE),
        'Conv2D': Prim(OPAQUE),
        'CReal': Prim(ELEMWISE),
        'CImag': Prim(ELEMWISE),
        'Complex': Prim(ELEMWISE),
        'Atan': Prim(ELEMWISE),
        'Atan2': Prim(ELEMWISE),
        'Expm1': Prim(ELEMWISE),
        'TensorScatterAdd': Prim(OPAQUE),
        'Gather': Prim(OPAQUE),
        'GatherNd': Prim(OPAQUE),
        'UnsortedSegmentSum': Prim(OPAQUE),
        'StandardNormal': Prim(OPAQUE),
        'UserDefined': Prim(OPAQUE),
    }

    default_primtive = Prim(UNKNOWN)

    @classmethod
    def get_prim(cls, op):
        """Get op primtive; falls back to the UNKNOWN prim with a warning."""
        prim = cls.primtives.get(op.prim, None)
        if prim is None:
            print('[WARN] primtive is not registered: ' + op.prim)
            prim = cls.default_primtive
        return prim

    @classmethod
    def input_relation(cls, op, input_idx):
        """Get op's input_relation according to input_idx."""
        return cls.get_prim(op).relation_func(op, input_idx)

    @classmethod
    def iter_type(cls, op):
        """Get op's iter type."""
        return cls.get_prim(op).iter_type

    @classmethod
    def is_reduce(cls, op):
        """Check whether op's iter type is reduce."""
        return cls.get_prim(op).iter_type == cls.REDUCE

    @classmethod
    def calibrate_iter_size(cls, op, iter_size):
        """Scale iter_size by the prim's calibration factor."""
        return cls.get_prim(op).calibrate * iter_size

    @classmethod
    def dtype_bytes(cls, dtype):
        """Get dtype bytes by parsing the trailing bit-width digits.

        The accumulator starts at 1 so widths like 16/32/64 round to the
        correct byte count under integer division by 8 (e.g. 'float16' ->
        (1+16)//8 == 2). Digit-less names (e.g. 'bool') yield 0.
        """
        bits, unit = 1, 1
        for i in range(len(dtype) - 1, 0, -1):
            if dtype[i].isdecimal():
                bits += int(dtype[i]) * unit
                unit *= 10
            else:
                break
        return bits // 8

    @classmethod
    def inplace_reuse(cls, op, input_idx, start_axis=0):
        """Check whether op's output may reuse the given input's buffer in place.

        Requires the output dtype not to be wider than the input dtype, and
        a pure ELEMWISE relation on every axis from start_axis onward.
        """
        if cls.dtype_bytes(op.output.dtype) > cls.dtype_bytes(op.inputs[input_idx].dtype):
            return False
        _, elem_relation = cls.get_prim(op).relation_func(op, input_idx)
        for i in range(start_axis, len(elem_relation)):
            if elem_relation[i] != cls.ELEMWISE:
                return False
        return True


class Tensor:
    """A graph tensor: name, shape, dtype, format and producer/consumer links."""

    # Parameter roles.
    PARA_NONE = 0
    PARA_INPUT = 1
    PARA_OUTPUT = 2

    class Buddy:
        """Group of tensors that must be allocated/handled together."""

        def __init__(self, leader):
            self.members = [leader]

    def __init__(self, name, shape, dtype, data_format=DataFormat.DEFAULT, para_type=0):
        self.name = name
        self.shape = shape
        self.dtype = dtype
        self.data_format = data_format
        self.para_type = para_type
        self.op = None      # producer Operator, set when an Operator outputs this tensor
        self.to_ops = []    # consumer Operators
        self.buddy = None

    def __str__(self):
        return self.name + str(list(self.shape))

    def __repr__(self):
        return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))

    def get_size(self):
        """Get size in bytes (dtype bytes times element count)."""
        size = PrimLib.dtype_bytes(self.dtype)
        for i in self.shape:
            size *= i
        return size

    def add_buddy(self, tensor):
        """Add tensor to this tensor's buddy group, creating the group if needed."""
        if self.buddy is None:
            self.buddy = self.Buddy(self)
        self.buddy.members.append(tensor)
        tensor.buddy = self.buddy


class Value:
    """A scalar constant value node (shape is always [1])."""

    def __init__(self, name, dtype, value, data_format=DataFormat.DEFAULT):
        self.name = name
        self.shape = [1]
        self.dtype = dtype
        self.value = value
        self.data_format = data_format

    def __str__(self):
        return self.name + str(list(self.shape))

    def __repr__(self):
        return "%s.%s%s" % (self.name, self.dtype, str(list(self.shape)))

    def get_size(self):
        """Get size (a scalar occupies one unit)."""
        return 1


class Operator:
    """A graph operator: primitive name, input tensors, single output, attrs.

    Registers itself as a consumer of each input and as the producer of the
    output (if the output has no producer yet).
    """

    def __init__(self, primtive, inputs, output, attrs):
        self.prim = primtive
        self.inputs = inputs
        self.output = output
        self.attrs = attrs
        for t in inputs:
            t.to_ops.append(self)
        if output.op is None:
            output.op = self
        self.all_inputs = []  # include Tensor inputs and Value inputs.

    def __str__(self):
        args = ', '.join([str(t) for t in self.all_inputs])
        expr = "%s = %s.%s(%s) id:%s" % (
            str(self.output), self.prim, self.output.dtype, args, id(self))
        return expr if not self.attrs else '%s // %s' % (expr, str(self.attrs))

    def __repr__(self):
        return str(self)


class Graph:
    """A computation graph: ops in topological order plus stitch/recompute info."""

    def __init__(self, name, ops, stitch_info=None, recompute_ops=None):
        self.name = name
        self.ops = ops  # in topo order, can not use set
        self.inputs = []
        self.outputs = []
        self.stitch_info = stitch_info
        self.recompute_ops = recompute_ops
        self.processor = ""

    def set_processor(self, processor):
        """Set processor (target backend name)."""
        self.processor = processor

    def add(self, ops):
        """Add one Operator or an iterable of Operators."""
        if isinstance(ops, Operator):
            self.ops.append(ops)
        else:
            self.ops.extend(ops)

    def extract_subgraph(self, graph_name, tensor_names, difference=False):
        """Extract subgraph from this graph.

        With difference=False, keeps ops whose output name is in tensor_names
        (and raises if some name matches no op); with difference=True, keeps
        the complement.
        """
        graph = Graph(graph_name, [])
        outputs = set(tensor_names)
        if difference:
            for op in self.ops:
                if op.output.name not in outputs:
                    graph.add(op)
        else:
            for op in self.ops:
                if op.output.name in outputs:
                    graph.add(op)
                    outputs.remove(op.output.name)
            for name in outputs:
                raise ValueError("invalid input tensor : " + name)
        return graph

    def deduce_parameters(self):
        """Deduce graph inputs and outputs from op connectivity.

        Inputs are tensors consumed here but produced outside; outputs are
        tensors marked PARA_OUTPUT, with no consumers, or consumed by ops
        outside this graph. Explicit self.inputs/self.outputs override.
        """
        inputs, outputs = [], []
        for op in self.ops:
            for t in op.inputs:
                if t not in inputs and t.op not in self.ops:
                    inputs.append(t)
            if op.output in outputs:
                continue
            if op.output.para_type == Tensor.PARA_OUTPUT or not op.output.to_ops:
                outputs.append(op.output)
                continue
            if any(succ not in self.ops for succ in op.output.to_ops):
                outputs.append(op.output)
        if self.inputs:
            inputs = self.inputs

        if self.outputs:
            outputs = self.outputs
        return inputs, outputs

    def __str__(self):
        inputs, outputs = self.deduce_parameters()
        para_str = ', '.join([repr(t) for t in inputs])
        out_str = ', '.join([repr(t) for t in outputs])
        lines = []
        lines.append("%s(%s) -> %s {" % (self.name, para_str, out_str))
        if self.stitch_info:
            if self.stitch_info.stitch_ops:
                lines.append('  stitch -> ' + str(self.stitch_info.stitch_ops))
            if self.stitch_info.stitch_atomic_ops:
                lines.append('  stitch_atomic_ops-> ' + str(self.stitch_info.stitch_atomic_ops))

        for op in self.ops:
            lines.append('  ' + str(op))
        lines.append('}')
        return '\n'.join(lines)

    def __repr__(self):
        return str(self)

    def dump(self):
        """Dump Graph to a json-serializable dict in AKG composite format."""
        attr_name = {'reduce_axis': 'axis'}  # rename graph attr keys on export
        inputs, outputs = self.deduce_parameters()
        input_desc, output_desc, op_desc = [], [], []
        for t in inputs:
            input_desc.append([{'data_type': t.dtype, 'shape': t.shape,
                                'tensor_name': t.name, 'format': t.data_format}])
        for t in outputs:
            output_desc.append({'data_type': t.dtype, 'shape': t.shape,
                                'tensor_name': t.name, 'format': t.data_format})
        for op in self.ops:
            attrs, in_desc = [], []
            for a in op.attrs:
                name = attr_name.get(a, a)
                attrs.append(
                    {'name': name, 'value': op.attrs[a], 'data_type': Utils.get_attr_type(op.attrs[a])})
            for t in op.all_inputs:
                if isinstance(t, Tensor):
                    in_desc.append([{'data_type': t.dtype, 'name': '', 'shape': t.shape,
                                     'tensor_name': t.name, 'format': t.data_format}])
                else:
                    in_desc.append([{'data_type': t.dtype, 'value': t.value, 'name': '', 'shape': t.shape,
                                     'tensor_name': t.name, 'format': t.data_format}])
            out_desc = [{'data_type': op.output.dtype, 'name': '', 'shape': op.output.shape,
                         'tensor_name': op.output.name, 'format': op.output.data_format}]
            op_desc.append({'attr': attrs, 'impl_path': '',
                            'input_desc': in_desc, 'name': op.prim, 'output_desc': out_desc})

        graph_desc = {'composite': True, 'composite_graph': '', 'id': 0,
                      'input_desc': input_desc, 'op': self.name, 'op_desc': op_desc, 'output_desc': output_desc,
                      'platform': 'AKG', 'process': self.processor}

        if self.stitch_info and self.stitch_info.stitch_ops:
            buffer_stitch = {'stitch_op': list(self.stitch_info.stitch_ops)}
            if self.stitch_info.stitch_atomic_ops:
                buffer_stitch['stitch_atomic_op'] = list(self.stitch_info.stitch_atomic_ops)
            graph_desc['buffer_stitch'] = buffer_stitch

        return graph_desc


class GraphVisitor:
    """Base visitor: walks graph.ops forward or backward, calling self.visit(op)."""

    def __init__(self, forward=True):
        self.forward = forward

    def visit_graph(self, graph):
        """Visit graph ops in topo (or reverse-topo) order."""
        if self.forward:
            for op in graph.ops:
                self.visit(op)
        else:
            for i in range(len(graph.ops)-1, -1, -1):
                self.visit(graph.ops[i])


class AlignShape(GraphVisitor):
    """Pad output shapes with leading 1s so ranks match the widest input."""

    def __init__(self):
        super(AlignShape, self).__init__()

    def visit(self, op):
        """Visit op node: only elemwise/broadcast/reduce ops need alignment."""
        prim = PrimLib.get_prim(op)
        if prim.iter_type in (PrimLib.ELEMWISE, PrimLib.BROADCAST, PrimLib.REDUCE):
            out_dim = len(op.output.shape)
            align_dim = out_dim
            for t in op.inputs:
                if len(t.shape) > align_dim:
                    align_dim = len(t.shape)
            if align_dim > out_dim:
                op.output.shape = [1] * (align_dim - out_dim) + op.output.shape


class AddControlBuddy(GraphVisitor):
    """Attach MakeTuple control ops as buddies of the op that consumes them."""

    def __init__(self):
        super(AddControlBuddy, self).__init__()
        self.buddies = {}  # {op : [ctrl_op]}

    def visit(self, op):
        """Visit op node: record MakeTuple ops under their single consumer."""
        if op.prim == "MakeTuple":
            if len(op.output.to_ops) != 1:
                raise ValueError("operator's output size is abnormal")
            owner = op.output.to_ops[0]
            if owner in self.buddies:
                self.buddies[owner].append(op)
            else:
                self.buddies[owner] = [op]
            # If this MakeTuple already owned buddies, transfer them to the owner.
            if op in self.buddies:
                ops = self.buddies.pop(op)
                self.buddies[owner].extend(ops)

    def visit_graph(self, graph):
        """Visit graph nodes, then link recorded buddies to their owners."""
        super(AddControlBuddy, self).visit_graph(graph)
        for owner in self.buddies:
            for op in self.buddies[owner]:
                owner.add_buddy(op.output)


class GraphKernelUnsupportedException(Exception):
    """"GraphKernel Unsupported Exception"""

    def __init__(self, message):
        super(GraphKernelUnsupportedException, self).__init__()
        self.message = message