# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""primitive"""
import inspect
import copy
from mindspore.common.api import _wrap_func
from mindspore import context, log as logger
from mindspore.parallel._utils import _is_in_auto_parallel_mode
from .._c_expression import Primitive_, real_run_op, prim_type
from .._checkparam import Validator
from . import signature as sig


class Primitive(Primitive_):
    """
    Primitive is the base class of operator primitives in python.

    Args:
        name (str): Name for the current Primitive.

    Examples:
        >>> add = Primitive('add')
        >>>
        >>> # or work with prim_attr_register:
        >>> # init a Primitive class with attr1 and attr2
        >>> class Add(Primitive):
        ...     @prim_attr_register
        ...     def __init__(self, attr1, attr2):
        ...         '''init for add'''
        ...         # check attr1 and attr2 or do some initializations
        ...
        >>> # init a Primitive obj with attr1=1 and attr2=2
        >>> add = Add(attr1=1, attr2=2)
    """
    # Attribute names that are left out of the string built by __repr__.
    _repr_ignore_list = ['input_names', 'output_names']

    def __init__(self, name):
        self.name = name
        # Attributes registered through add_prim_attr().
        self.attrs = {}
        # Keyword arguments used by __deepcopy__() to rebuild an equivalent object.
        self.init_attrs = {"name": name}
        self._update_parameter = False
        Primitive_.__init__(self, name)
        if hasattr(self.__class__, '__mindspore_signature__'):
            out = self._fill_signature(self.__class__.__mindspore_signature__)
            self.set_signatures(out)

    def _fill_signature(self, signatures):
        """
        Normalize the class-level signature description into a tuple of signatures.

        Args:
            signatures (Iterable): Each element is a `sig.Signature` (kept as is), a `sig.sig_dtype`
                (wrapped with `sig.make_sig(dtype=...)`), or a sequence of at least three items that is
                unpacked into `sig.make_sig`.

        Returns:
            tuple, the normalized signatures.

        Raises:
            ValueError: If a sequence element has fewer than three items.
        """
        signatures_new = []
        for signature in signatures:
            if isinstance(signature, sig.Signature):
                signatures_new.append(signature)
            elif isinstance(signature, sig.sig_dtype):
                signatures_new.append(sig.make_sig(dtype=signature))
            else:
                if len(signature) < 3:
                    raise ValueError(f"[Internal Error]Signature for one parameter len must >= 3, but {signature}")
                signatures_new.append(sig.make_sig(*signature))
        return tuple(signatures_new)

    def _clone(self):
        """
        Deeply clones the primitive object.

        Calls the __init__() method with the same arguments. This method is called in parser if the
        flag self.__setattr_flag__ is True.

        Returns:
            Primitive, the cloned object carrying the same attributes and instance name.
        """
        cloned = copy.deepcopy(self)
        # The first positional parameter of the decorated __init__ is `self`, so it is skipped.
        init_params = inspect.getfullargspec(cloned.__init__.decorated_func).args[1:]
        init_args = {name: self.attrs[name] for name in init_params}
        # __init__ should be called to construct cpp object.
        cloned.__init__(**init_args)
        for name, value in self.attrs.items():
            cloned.add_prim_attr(name, value)
        if hasattr(self, 'instance_name'):
            cloned.set_prim_instance_name(self.instance_name)
        return cloned

    def add_prim_attr(self, name, value):
        """
        Add primitive attribute.

        Args:
            name (str): Attribute Name.
            value (Any): Attribute value.

        Returns:
            Primitive, the primitive itself, so calls can be chained.

        Examples:
            >>> import mindspore.ops as ops
            >>> a = ops.Add()
            >>> a = a.add_prim_attr("attr",1)
            >>> out = a.attrs["attr"]
            >>> print(out)
            1
        """
        # The value is stored in three places: the instance dict, `attrs` and the base-class attribute store.
        self.__dict__[name] = value
        self.attrs[name] = value
        self.add_attr(name, value)
        return self

    def del_prim_attr(self, name):
        """
        Delete primitive attribute.

        The attribute is only removed when it is present both on the instance and in `attrs`;
        otherwise the call does nothing.

        Args:
            name (str): Attribute Name.

        Returns:
            Primitive, the primitive itself, so calls can be chained.

        Examples:
            >>> import mindspore.ops as ops
            >>> a = ops.Add()
            >>> a = a.add_prim_attr("attr",1)
            >>> a = a.del_prim_attr("attr")
            >>> print(a.attrs)
            {'input_names': ['x', 'y'], 'output_names': ['output']}
        """
        if name in self.__dict__ and name in self.attrs:
            del self.__dict__[name]
            del self.attrs[name]
            self.del_attr(name)
        return self

    def set_stage(self, stage):
        """
        Add stage id to primitive attribute.

        Note:
            It is valid only in semi auto parallel.
            In other parallel modes, please set it to be 0.

        Args:
            stage (int): The stage id for the current operation.

        Returns:
            Primitive, the primitive itself, so calls can be chained.

        Examples:
            >>> from mindspore import ops
            >>> add = ops.Add()
            >>> print(add.set_stage(0))
            Prim[Add]<stage=0>
        """
        self.add_prim_attr("stage", stage)
        return self

    def shard(self, strategy):
        """
        Add strategies to primitive attribute.

        Note:
            It is valid only in semi auto parallel or auto parallel mode.
            In other parallel modes, strategies set here will be ignored.

        Args:
            strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.

        Returns:
            Primitive, the primitive itself, so calls can be chained.

        Raises:
            TypeError: If `strategy` is not None and is not a tuple, or one of its elements is not a tuple.

        Examples:
            >>> from mindspore import ops
            >>> add = ops.Add()
            >>> print(add.shard(((1, 1), (1, 1))))
            Prim[Add]<strategy=((1, 1), (1, 1))>
        """
        mode = context.get_auto_parallel_context("parallel_mode")
        if strategy is not None:
            if not isinstance(strategy, tuple):
                raise TypeError(f'strategy must be tuple type, but got:{type(strategy)}')
            for ele in strategy:
                if not isinstance(ele, tuple):
                    raise TypeError(f'The element of strategy must be tuple type, but got:{type(ele)}')
        # A non-empty strategy outside the auto parallel modes is only warned about; it is still recorded.
        if not _is_in_auto_parallel_mode() and strategy:
            logger.warning(f"The shard strategy {strategy} of {self.name} is not valid in {mode}. "
                           f"Please use semi auto or auto parallel mode.")
        self.add_prim_attr("strategy", strategy)
        return self

    def set_prim_instance_name(self, instance_name):
        """
        Set instance name to primitive operator.

        Note:
            It will be called by default when user defines primitive operator.

        Args:
            instance_name (str): Instance name of primitive operator set by user.

        Returns:
            Primitive, the primitive itself, so calls can be chained.

        Examples:
            >>> import mindspore.ops as ops
            >>> a = ops.Add()
            >>> a.set_prim_instance_name("add")
            >>> print(a.instance_name)
            add
        """
        self.set_instance_name(instance_name)
        self.instance_name = instance_name
        return self

    def __getattr__(self, item):
        """
        Resolve attributes that normal lookup did not find.

        `infer_dynamic_shape` resolves to None. Any other name is searched in the attribute dict of the
        base class first and then in `attrs`.

        Raises:
            AttributeError: If `item` is found in neither place.
        """
        if item == 'infer_dynamic_shape':
            return None
        # Fetch the attribute dict once instead of once for the test and once for the read.
        attr_dict = super().get_attr_dict()
        if item in attr_dict:
            return attr_dict[item]
        # Read `attrs` from the instance dict: going through `self.attrs` would re-enter __getattr__
        # without bound whenever `attrs` itself has not been set yet.
        attrs = self.__dict__.get('attrs', {})
        if item in attrs:
            return attrs[item]
        raise AttributeError(item)

    def check_elim(self, *args):
        """
        Check if the primitive can be eliminated. Subclass in need should override this method.

        Args:
            args(Primitive args): Same as arguments of current Primitive.

        Returns:
            A tuple consisting of two elements.
            The first element means if the primitive can be calculated in compiling stage,
            the second element is calculated result.

        Examples:
            >>> class AddN(Primitive):
            ...     @prim_attr_register
            ...     def __init__(self):
            ...         self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
            ...     def check_elim(self, inputs):
            ...         if len(inputs) != 1:
            ...             return (False, None)
            ...         if isinstance(inputs[0], Tensor):
            ...             return (True, inputs[0])
            ...
            >>> addn = AddN()
            >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float32)
            >>> output = addn.check_elim((input_x,))
            >>> print(output)
            (True, Tensor(shape=[3], dtype=Float32, value= [ 1.00000000e+00, 2.00000000e+00, 3.00000000e+00]))
        """
        return (False, None)

    def __call__(self, *args):
        """Return the result of check_elim() when it allows elimination, otherwise run the operator."""
        should_elim, output = self.check_elim(*args)
        if should_elim:
            return output
        return _run_op(self, self.name, args)

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, d):
        self.__dict__.update(d)

    def __deepcopy__(self, memo):
        # A deep copy is a fresh object built from the recorded constructor arguments.
        return type(self)(**self.init_attrs)

    def __repr__(self):
        attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if k not in Primitive._repr_ignore_list])
        info_str = f'Prim[{self.name}]'
        if attr:
            info_str += f'<{attr}>'
        return info_str

    def init_prim_io_names(self, inputs, outputs):
        """
        Initialize the name of inputs and outputs of Tensor or attributes.

        Args:
            inputs (list[str]): list of inputs names.
            outputs (list[str]): list of outputs names.

        Examples:
            >>> import mindspore.ops as ops
            >>> a = ops.Add()
            >>> a.init_prim_io_names(["x","y"],["sum"])
            >>> print(a.input_names)
            ['x', 'y']
            >>> print(a.output_names)
            ['sum']
        """
        # for checking para names with kernel implementation
        self.add_prim_attr("input_names", inputs)
        # for checking output number with kernel implementation
        self.add_prim_attr("output_names", outputs)

    @property
    def update_parameter(self):
        """Return whether the primitive will update the value of parameter."""
        return self._update_parameter

    def recompute(self, mode=True):
        """
        Set the primitive recomputed. If a primitive set recomputed feeds into some backward nodes
        for computing gradient, rather than storing the intermediate activation computed in forward
        pass, we will recompute it in backward pass.

        Note:

            - If the computation involves something like randomization or global variable, the equivalence
              is not guaranteed currently.
            - Not supported in pynative mode

        Args:
            mode (bool): Specifies whether the primitive is recomputed. Default: True.

        Returns:
            Primitive, the primitive itself, so calls can be chained.

        Raises:
            TypeError: If the current execution mode is pynative mode.

        Examples:
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import Tensor, ops, nn
            >>> class NetRecompute(nn.Cell):
            ...     def __init__(self):
            ...         super(NetRecompute,self).__init__()
            ...         self.relu = ops.ReLU().recompute()
            ...         self.sqrt = ops.Sqrt()
            ...     def construct(self, x):
            ...         out = self.relu(x)
            ...         return self.sqrt(out)
            ...
            >>> class GradNet(nn.Cell):
            ...     def __init__(self, network):
            ...         super(GradNet,self).__init__()
            ...         self.network = network
            ...         self.grad = ops.GradOperation()
            ...     def construct(self, x):
            ...         g_out = self.grad(self.network)(x)
            ...         return g_out
            ...
            >>> x = Tensor(np.array([-1,1]).astype(np.float32))
            >>> net = NetRecompute()
            >>> grad = GradNet(net)
            >>> a = grad(x)
            >>> print(a)
            [0. 0.5]
        """
        if context.get_context("mode") == context.PYNATIVE_MODE:
            raise TypeError("Recompute is not supported in pynative mode currently.")
        Validator.check_bool(mode)
        self.add_prim_attr("recompute", mode)
        return self


class PrimitiveWithCheck(Primitive):
    """
    PrimitiveWithCheck is the base class of primitives in python that defines functions for checking operator
    input arguments but uses the infer method registered in c++ source codes.

    Three methods can be overridden to define the check logic of the primitive: __check__(), check_shape() and
    check_dtype(). If __check__() is defined in primitive, the __check__() has highest priority to be called.
    If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic of
    the shape and type. Method infer_value() can also be defined (such as PrimitiveWithInfer) for constant propagation.

    Args:
        name (str): Name of the current Primitive.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # init a Primitive class with check
        >>> class Flatten(PrimitiveWithCheck):
        ...     @prim_attr_register
        ...     def __init__(self):
        ...         pass
        ...     def check_shape(self, input_x):
        ...         validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name)
        ...
        ...     def check_dtype(self, input_x):
        ...         validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
        ...
        >>> # init a Primitive obj
        >>> add = Flatten()
    """

    def __init__(self, name):
        Primitive.__init__(self, name)
        self.set_prim_type(prim_type.py_infer_check)

    def _clone(self):
        """
        Deeply clones the primitive object.

        Calls the __init__() method with the same arguments. This method is called in parser if the
        flag self.__setattr_flag__ is True.
        """
        cloned_prim = Primitive._clone(self)
        return cloned_prim

    def check_shape(self, *args):
        """
        Check shapes of input args.

        Note:
            The shape of scalar is an empty tuple.

        Args:
            args (tuple(int)): shapes of input tensors.

        Return:
            None.
        """
        return None

    def check_dtype(self, *args):
        """
        Check data types of input args.

        Args:
            args (:class:`mindspore.dtype`): data type of inputs.

        Return:
            None.
        """
        return None

    def __check__(self, *args):
        """
        Check that the input types and the input shapes of the operator are valid.

        Each element of `args` is indexed with 'dtype' and 'shape'; check_dtype() is called first,
        then check_shape().
        """
        tracks = ['dtype', 'shape']
        for track in tracks:
            fn = getattr(self, 'check_' + track)
            fn(*(x[track] for x in args))


class PrimitiveWithInfer(Primitive):
    """
    PrimitiveWithInfer is the base class of primitives in python and defines functions for tracking inference
    in python.

    Four methods can be overridden to define the infer logic of the primitive: __infer__(), infer_shape(),
    infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has highest priority
    to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe the infer
    logic of the shape and type. The infer_value() is used for constant propagation.

    Args:
        name (str): Name of the current Primitive.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # init a Primitive class with infer
        >>> class Add(PrimitiveWithInfer):
        ...     @prim_attr_register
        ...     def __init__(self):
        ...         pass
        ...
        ...     def infer_shape(self, x, y):
        ...         return x # output shape same as first input 'x'
        ...
        ...     def infer_dtype(self, x, y):
        ...         return x # output type same as first input 'x'
        ...
        >>> # init a Primitive obj
        >>> add = Add()
    """

    def __init__(self, name):
        Primitive.__init__(self, name)
        self.set_prim_type(prim_type.py_infer_shape)

    def _clone(self):
        """
        Deeply clones the primitive object.

        Calls the __init__() method with the same arguments. This method is called in parser if the
        flag self.__setattr_flag__ is True.
        """
        cloned_prim = Primitive._clone(self)
        return cloned_prim

    def infer_shape(self, *args):
        """
        Infer output shape based on input shape.

        Note:
            The shape of scalar is an empty tuple.

        Args:
            args (tuple(int)): shapes of input tensors.

        Return:
            `tuple(int)`, shapes of output tensors.
        """
        return None

    def infer_dtype(self, *args):
        """
        Infer output dtype based on input dtype.

        Args:
            args (:class:`mindspore.dtype`): data type of inputs.

        Return:
            :class:`mindspore.dtype`, data type of outputs.
        """
        return None

    def infer_value(self, *args):
        """
        Infer output value based on input value at compile time.

        Args:
            args (Any): value of inputs.

        Return:
            Value of outputs. Return `None`, the value can not be inferred at compile time in this case.
        """
        return None

    def __infer__(self, *args):
        """
        Infer shape, type, and value at the same time by using dictionary as arguments.

        Args:
            args (dict): One dictionary per input, indexed with 'dtype', 'shape' and 'value', and optionally
                'min_shape' and 'max_shape'.

        Returns:
            dict, with the keys 'dtype', 'shape' and 'value'. In graph mode, when the inferred shape is
            dynamic and every bound is available, 'min_shape' and 'max_shape' are added.

        Raises:
            ValueError: If the inputs specify only one of 'min_shape' and 'max_shape'.
        """
        is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
        fn_infer_dynamic_shape = getattr(self, 'infer_dynamic_shape', None)
        if is_graph_mode and fn_infer_dynamic_shape is not None:
            # The primitive provides its own dynamic shape inference; only dtype and value remain.
            out = fn_infer_dynamic_shape(*args)
            tracks = ['dtype', 'value']
            for track in tracks:
                fn = getattr(self, 'infer_' + track)
                # fn may return None
                out[track] = fn(*(x[track] for x in args))
            return out

        tracks = ['dtype', 'shape', 'value']
        out = {}
        for track in tracks:
            fn = getattr(self, 'infer_' + track)
            # fn may return None
            out[track] = fn(*(x[track] for x in args))

        # in non-graph_mode, it is not necessary to infer min/max shape
        if not is_graph_mode:
            return out

        # output does not contain dynamic shape, no need to calculate min/max shape
        def has_dynamic_shape(shp):
            if isinstance(shp, int):
                return shp < 0
            if isinstance(shp, (list, tuple)):
                return any(has_dynamic_shape(e) for e in shp)
            return False

        if not has_dynamic_shape(out['shape']):
            return out

        # calculate min/max shape for output
        def get_specified_shape(elems, attr):
            has_specified_shape = False
            ret_vals = []
            for elem in elems:
                if attr in elem:
                    has_specified_shape = True
                    ret_vals.append(elem[attr])
                else:
                    # Inputs without the bound contribute their plain shape instead.
                    ret_vals.append(elem['shape'])
            return has_specified_shape, tuple(ret_vals)

        has_min_shape, min_shapes = get_specified_shape(args, 'min_shape')
        has_max_shape, max_shapes = get_specified_shape(args, 'max_shape')
        if not (has_min_shape or has_max_shape):
            return out
        if has_min_shape and has_max_shape:
            # infer_shape() is the fallback when no dedicated min/max inference is defined.
            fn_infer_min_shape = getattr(self, 'infer_shape')
            fn_infer_max_shape = fn_infer_min_shape
            if hasattr(self, 'infer_min_shape'):
                fn_infer_min_shape = getattr(self, 'infer_min_shape')
            if hasattr(self, 'infer_max_shape'):
                fn_infer_max_shape = getattr(self, 'infer_max_shape')
            out['min_shape'] = fn_infer_min_shape(*min_shapes)
            out['max_shape'] = fn_infer_max_shape(*max_shapes)
            return out
        raise ValueError(f'Input args has invalid dynamic shape, args info: {args}')


def prim_attr_register(fn):
    """
    Primitive attributes register.

    Register the decorator of the built-in operator primitive '__init__'.
    The function will add all the parameters of '__init__' as operator attributes,
    and init primitive name.

    Args:
        fn (function): __init__ function of primitive.

    Returns:
        function, the wrapper that replaces '__init__'. The original function is kept in its
        `decorated_func` attribute.

    Examples:
        >>> class MatMul(PrimitiveWithCheck):
        ...     @prim_attr_register
        ...     def __init__(self, transpose_a=False, transpose_b=False):
        ...         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
        ...
        >>> # init a Primitive obj
        >>> matmul = MatMul()
    """

    def deco(self, *args, **kwargs):
        class_name = self.__class__.__name__
        if hasattr(self.__class__, "substitute_name"):
            class_name = self.__class__.substitute_name
        if isinstance(self, PrimitiveWithInfer):
            PrimitiveWithInfer.__init__(self, class_name)
        elif isinstance(self, PrimitiveWithCheck):
            PrimitiveWithCheck.__init__(self, class_name)
        else:
            # NOTE(review): unlike the two branches above, this passes the raw class name and therefore
            # ignores `substitute_name` - confirm whether that is intentional.
            Primitive.__init__(self, self.__class__.__name__)
        bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments
        del arguments['self']
        # Replace the recorded constructor arguments with the ones of the decorated __init__.
        del self.init_attrs['name']
        for name in arguments:
            value = arguments[name]
            self.add_prim_attr(name, value)
            self.init_attrs[name] = value
        fn(self, *args, **kwargs)

    # Primitive._clone() reads the parameter names from the undecorated function.
    deco.decorated_func = fn
    return deco


def constexpr(fn=None, get_instance=True, name=None):
    """
    Creates a PrimitiveWithInfer operator that can infer the value at compile time. We can use it to define a function
    to compute constant value using the constants in the constructor.

    Args:
        fn (function): A `fn` use as the infer_value of the output operator. Default: None.
        get_instance (bool): If true, return the instance of operator,
                             otherwise return the operator class. Default: True.
        name (str): Defines the operator name. If `name` is None, use the function name as op name. Default: None.

    Returns:
        The operator instance or the operator class when `fn` is given; otherwise a decorator that
        produces one of them.

    Examples:
        >>> from mindspore.ops import constexpr
        >>> a = (1, 2)
        >>> # make an operator to calculate tuple len
        >>> @constexpr
        ... def tuple_len(x):
        ...     return len(x)
        ...
        >>> print(tuple_len(a))
        2
        >>> # make an operator class to calculate tuple len
        >>> @constexpr(get_instance=False, name="TupleLen")
        ... def tuple_len_class(x):
        ...     return len(x)
        ...
        >>> print(tuple_len_class()(a))
        2
    """

    def deco(fn):
        """Decorator for CompileOp."""

        class CompileOp(PrimitiveWithInfer):
            """
            CompileOp is a temporary operator used to execute the constexpr function.
            """

            def __init__(self):
                op_name = name if name else fn.__name__
                PrimitiveWithInfer.__init__(self, op_name)
                self.set_const_prim(True)

            def infer_value(self, *args):
                return fn(*args)

            def __call__(self, *args, **kwargs):
                # Keyword arguments are forwarded rather than silently discarded.
                return fn(*args, **kwargs)

        if get_instance:
            return CompileOp()
        return CompileOp

    # Supports both the bare `@constexpr` form and the parameterized `@constexpr(...)` form.
    if fn is not None:
        return deco(fn)
    return deco


@_wrap_func
def _run_op(obj, op_name, args):
    """Single op execution function supported by ge in PyNative mode."""
    output = real_run_op(obj, op_name, args)
    return output