# Copyright 2020-2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""primitive"""
import functools
import inspect
import copy
import numpy as np
from mindspore.common.api import _wrap_func
from mindspore.log import _LogActionOnce
from mindspore import context, log as logger
from mindspore.parallel._utils import _is_in_auto_parallel_mode, _is_in_data_parallel_mode, \
    _is_in_hybrid_parallel_mode, SUPPORTED_TUPLE_IN_TUPLE_STRATEGY
from mindspore.parallel._ps_context import _is_ps_mode, _is_role_sched
from mindspore.parallel.shard import Layout
from mindspore.common.api import _pynative_executor
from mindspore.common._stub_tensor import _convert_stub
from mindspore._c_expression import Primitive_, PrimitiveFunction_, prim_type, typing
from mindspore import _checkparam as Validator
from mindspore.ops import signature as sig


class Primitive(Primitive_):
    """
    Primitive is the base class of operator primitives in python.

    Args:
        name (str): Name for the current Primitive.

    Examples:
        >>> from mindspore.ops import prim_attr_register, Primitive
        >>> add = Primitive('add')
        >>>
        >>> # or work with prim_attr_register:
        >>> # init a Primitive class with attr1 and attr2
        >>> class Add(Primitive):
        ...     @prim_attr_register
        ...     def __init__(self, attr1, attr2):
        ...         '''init for add'''
        ...     # check attr1 and attr2 or do some initializations
        ...     # init a Primitive obj with attr1=1 and attr2=2
        >>> add = Add(attr1=1, attr2=2)
    """
    # Attributes that are kept in `attrs` but hidden from __repr__.
    _repr_ignore_list = ['input_names', 'output_names']

    def __init__(self, name):
        self.name = name
        self.attrs = {}
        self.init_attrs = {"name": name}
        self._update_parameter = False
        Primitive_.__init__(self, name)
        # Subclasses may declare a class-level signature; normalize it and hand it to the C++ side.
        if hasattr(self.__class__, '__mindspore_signature__'):
            out = self._fill_signature(self.__class__.__mindspore_signature__)
            self.set_signatures(out)

    def add_prim_attr(self, name, value):
        """
        Add primitive attribute.

        Args:
            name (str): Attribute Name.
            value (Any): Attribute value.

        Examples:
            >>> from mindspore import ops
            >>> a = ops.Add()
            >>> a = a.add_prim_attr("attr",1)
            >>> out = a.attrs["attr"]
            >>> print(out)
            1
        """
        # Keep the Python view (__dict__ / attrs) and the C++ attribute map in sync.
        self.__dict__[name] = value
        self.attrs[name] = value
        self.add_attr(name, value)
        return self

    def _set_prim_arg(self, name, value):
        """
        Set primitive initialization arguments.

        Different from add_prim_attr, it is used internally to store Primitive
        initialization arguments in Python.
        """
        # Unlike add_prim_attr, this does NOT call add_attr, so the C++ attribute map is untouched.
        self.__dict__[name] = value
        self.attrs[name] = value
        return self

    def _set_prim_arg_with_handler(self, name, value, arg_handler):
        """
        Set primitive initialization arguments and with arg_handler.

        `arg_handler(class_name, arg_name, value)` is applied to `value` unless it is None.
        """
        value = value if value is None else arg_handler(self.__class__.__name__, name, value)
        return self._set_prim_arg(name, value)

    def set_device(self, device_target):
        """
        Set primitive been executed device.

        Args:
            device_target (str): The target device to run, support "Ascend", "GPU", and "CPU".

        Examples:
            >>> from mindspore import ops
            >>> a = ops.Add()
            >>> a = a.set_device("GPU")
            >>> print(a.primitive_target)
            GPU
        """
        return self.add_prim_attr("primitive_target", device_target)

    def _fill_signature(self, signatures):
        """
        Normalize a class-level signature declaration into a tuple of `sig.Signature` objects.

        Each entry may be a ready `sig.Signature`, a bare `sig.sig_dtype` (wrapped via
        `sig.make_sig(dtype=...)`), or a sequence of at least three positional arguments
        for `sig.make_sig`.

        Raises:
            ValueError: If a sequence entry has fewer than three items.
        """
        signatures_new = []
        for signature in signatures:
            if isinstance(signature, sig.Signature):
                signatures_new.append(signature)
            elif isinstance(signature, sig.sig_dtype):
                signatures_new.append(sig.make_sig(dtype=signature))
            else:
                if len(signature) < 3:
                    raise ValueError(f"[Internal Error]Signature for one parameter len must >= 3, but {signature}")
                signatures_new.append(sig.make_sig(*signature))
        return tuple(signatures_new)

    def _clone(self):
        """
        Deeply clones the primitive object.

        Calls the __init__() method with the same arguments. This method is called in parser if the
        flag self.__setattr_flag__ is True.
        """
        cloned = copy.deepcopy(self)
        init_params = list()
        # Only decorated __init__ functions (prim_attr_register / prim_arg_register) expose their
        # original signature; skip `self` with [1:].
        if hasattr(cloned.__init__, 'decorated_func'):
            init_params = inspect.getfullargspec(cloned.__init__.decorated_func).args[1:]
        # NOTE(review): this aliases self.init_attrs, so the loop below also updates the source
        # primitive's init_attrs with its current attribute values — confirm this is intended.
        init_args = self.init_attrs
        for name in init_params:
            value = self.attrs[name]
            init_args[name] = value
        # __init__ should be called to construct cpp object.
        cloned.__init__(**init_args)
        for name in self.attrs:
            value = self.attrs[name]
            cloned.add_prim_attr(name, value)
        if hasattr(self, 'instance_name'):
            cloned.set_prim_instance_name(self.instance_name)
        return cloned

    def _check_shard_strategy(self, strategy, log_info):
        """
        Check shard strategy is validate or not.

        Args:
            strategy (tuple): Tuple whose elements are all tuples (of ints) or all `Layout` objects.
            log_info (str): Name used in error messages (e.g. "in_strategy").

        Returns:
            numpy.ndarray of bool, one entry per element, True where the element is a `Layout`.

        Raises:
            TypeError: If `strategy` is not a tuple, an element is neither tuple nor `Layout`,
                a tuple element contains a non-int (unless this primitive supports nested tuples),
                or tuples and Layouts are mixed.
        """
        is_layout = []
        if not isinstance(strategy, tuple):
            raise TypeError(f'{log_info} must be tuple type, but got:{type(strategy)}')
        for in_ele in strategy:
            if not isinstance(in_ele, tuple) and not isinstance(in_ele, Layout):
                raise TypeError(f'The element of strategy must be tuple/Layout type, but got:{type(in_ele)}')
            if isinstance(in_ele, tuple):
                for in_value in in_ele:
                    if not isinstance(in_value, int) and self.name not in SUPPORTED_TUPLE_IN_TUPLE_STRATEGY:
                        raise TypeError(f'The {log_info}: {strategy} of {self.name} is not valid,'
                                        f' the value of strategy must be int type, but got:{type(in_value)}')
                is_layout.append(False)
                continue
            is_layout.append(True)
        # Mixed tuple/Layout strategies are rejected. The check only makes sense (and indexing [0]
        # is only safe) when there is at least one element.
        if is_layout:
            np_is_layout = np.array(is_layout)
            if not (np_is_layout == np_is_layout[0]).all():
                raise TypeError(f'{log_info} item must be all tuple type or all Layout type.')
        return np.array(is_layout)

    def _extract_layout_value(self, layout, log_info):
        """
        Extract parallel layout value.

        Returns:
            None if `layout` is None, otherwise a tuple of each element's `to_dict()` result.

        Raises:
            TypeError: If `layout` is not a tuple or contains a non-`Layout` element.
        """
        layout_value = None
        if layout is not None:
            if not isinstance(layout, tuple):
                raise TypeError(f'{log_info} must be tuple type, but got:{type(layout)}')
            layout_value = ()
            for in_ele in layout:
                if not isinstance(in_ele, Layout):
                    raise TypeError(f"The {log_info} item should be a object of class Layout.")
                layout_value += (in_ele.to_dict(),)
        return layout_value

    def _check_shard_strategy_in_out_match(self, in_strategy, out_strategy):
        """
        Check shard in_strategy and out_strategy.

        Raises ValueError when only out_strategy is given, and warns that strategies are ignored
        when the current parallel mode is not an auto-parallel mode.
        """
        if in_strategy is None and out_strategy is not None:
            raise ValueError(f'The out_strategy of {self.name} is {out_strategy}, need to set in_strategy,'
                             f' but got none')
        if not _is_in_auto_parallel_mode():
            mode = context.get_auto_parallel_context("parallel_mode")
            if in_strategy is not None:
                logger.warning(f"The in_strategy/in_layout of the operator in your network "
                               f"will not take effect in {mode} mode. "
                               f"This means the the shard function called in the network is ignored. \n"
                               f"If you want to enable it, please use semi auto or auto parallel mode by "
                               f"context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL "
                               f"or context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL)")
            if out_strategy is not None:
                logger.warning(f"The out_strategy/out_layout of the operator in your network "
                               f"will not take effect in {mode} mode."
                               f" This means the the shard function called in the network is ignored. \n"
                               f"If you want to enable it, please use semi auto or auto parallel mode by "
                               f"context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL "
                               f"or context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL)")

    def del_prim_attr(self, name):
        """
        Delete primitive attribute.

        Args:
            name (str): Attribute Name.
        Examples:
            >>> from mindspore import ops
            >>> a = ops.Add()
            >>> a = a.add_prim_attr("attr",1)
            >>> a = a.del_prim_attr("attr")
            >>> print(a.attrs)
            {}
        """
        # Silently a no-op when the attribute is not present in both Python-side maps.
        if name in self.__dict__ and name in self.attrs:
            del self.__dict__[name]
            del self.attrs[name]
            self.del_attr(name)
        return self

    def set_stage(self, stage):
        """
        Add stage id to primitive attribute.

        Note:
            It is valid only in semi auto parallel.
            In other parallel modes, please set it to be 0.
        Args:
            stage (int): The stage id for the current operation.
        Examples:
            >>> from mindspore import ops
            >>> add = ops.Add()
            >>> print(add.set_stage(0))
            Prim[Add]<stage=0>
        """
        self.add_prim_attr("stage", stage)
        return self

    @_LogActionOnce(logger=logger, key='Primitive')
    def shard(self, in_strategy=None, out_strategy=None):
        """
        Add strategies to primitive attribute.

        Note:
            It is valid only in semi auto parallel or auto parallel mode.
            In other parallel modes, strategies set here will be ignored.

        Args:
            in_strategy (tuple): Describe the split strategy of operator input. Default: ``None`` .
            out_strategy (tuple): Describe the split strategy of operator output, it is only for certain operators,
                such as MatMul. Default: ``None`` .

        Examples:
            >>> from mindspore import ops
            >>> add = ops.Add()
            >>> print(add.shard(((1, 1), (1, 1))))
            Prim[Add]<in_strategy=((1, 1), (1, 1))>
            >>> # using layout
            >>> from mindspore import Layout
            >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
            >>> layout_tuple = (layout("dp", "sp"), layout("sp", "mp"))
            >>> from mindspore import ops
            >>> matmul = ops.MatMul()
            >>> print(matmul.shard(layout_tuple))
            Prim[MatMul]<in_layout=({'device_matrix': (2, 2, 2), 'tensor_map': (2, 1)},
            {'device_matrix': (2, 2, 2), 'tensor_map': (1, 0)})>
            >>> # using layout with None
            >>> from mindspore import Layout
            >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
            >>> layout_tuple = (layout("dp", "sp"), layout("sp", "None"))  # "None" means the axis would not be split
            >>> from mindspore import ops
            >>> matmul = ops.MatMul()
            >>> print(matmul.shard(layout_tuple))
            Prim[MatMul]<in_layout=({'device_matrix': (2, 2, 2), 'tensor_map': (2, 1)},
            {'device_matrix': (2, 2, 2), 'tensor_map': (1, -1)})>
        """
        # in_is_layout / out_is_layout are bool arrays from _check_shard_strategy (None when not given).
        # NOTE(review): an empty strategy tuple yields an empty array, so the [0] lookups below
        # would raise IndexError for it.
        in_is_layout = None
        out_is_layout = None
        if in_strategy is not None:
            in_is_layout = self._check_shard_strategy(in_strategy, "in_strategy")

        if out_strategy is not None:
            out_is_layout = self._check_shard_strategy(out_strategy, "out_strategy")
        self._check_shard_strategy_in_out_match(in_strategy, out_strategy)
        if in_is_layout is not None and out_is_layout is not None and in_is_layout[0] != out_is_layout[0]:
            raise ValueError(f'The in_strategy type must equal to the out_strategy type, '
                             f'one using tuple(tuple) and the other using tuple(Layout) is not allowed.')
        in_layout_value = None
        out_layout_value = None
        if in_is_layout is not None and in_is_layout[0]:
            in_layout_value = self._extract_layout_value(in_strategy, "in_strategy")
        if out_is_layout is not None and out_is_layout[0]:
            out_layout_value = self._extract_layout_value(out_strategy, "out_strategy")

        # Tuple strategies are stored as-is; Layout strategies are stored as their dict form.
        if in_is_layout is not None and not in_is_layout[0]:
            self.add_prim_attr("in_strategy", in_strategy)
        if out_is_layout is not None and not out_is_layout[0]:
            self.add_prim_attr("out_strategy", out_strategy)
        if in_layout_value:
            self.add_prim_attr("in_layout", in_layout_value)
        if out_layout_value:
            self.add_prim_attr("out_layout", out_layout_value)
        return self

    def set_prim_instance_name(self, instance_name):
        """
        Set instance name to primitive operator.

        Note:
            It will be called by default when user defines primitive operator.

        Args:
            instance_name (str): Instance name of primitive operator set by user.

        Examples:
            >>> from mindspore import ops
            >>> a = ops.Add()
            >>> a = a.set_prim_instance_name("add")
            >>> print(a.instance_name)
            add
        """
        self.set_instance_name(instance_name)
        self.instance_name = instance_name
        return self

    def __getattr__(self, item):
        """
        Fallback attribute lookup, only reached when normal lookup fails.

        Returns None for 'infer_dynamic_shape', otherwise looks `item` up in the C++ attribute
        dict and then in `self.attrs`.

        Raises:
            AttributeError: If `item` is found in neither place.
        """
        if item == 'infer_dynamic_shape':
            return None
        if item in super().get_attr_dict():
            return super().get_attr_dict()[item]
        if item in self.attrs:
            return self.attrs[item]
        err_msg = "'{prim}' object has no attribute '{attr}'".format(prim=self.name, attr=item)
        raise AttributeError(err_msg)

    def check_elim(self, *args):
        """
        Check if the primitive can be eliminated. Subclass in need should override this method.

        Args:
            args(Primitive args): Same as arguments of current Primitive.

        Returns:
            A tuple consisting of two elements.
            The first element means if the primitive can be calculated in compiling stage,
            the second element is calculated result.

        Examples:
            >>> import numpy as np
            >>> import mindspore
            >>> from mindspore import Tensor
            >>> from mindspore.ops import prim_attr_register, Primitive
            >>> class AddN(Primitive):
            ...     @prim_attr_register
            ...     def __init__(self):
            ...         self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
            ...     def check_elim(self, inputs):
            ...         if len(inputs) != 1:
            ...             return (False, None)
            ...         if isinstance(inputs[0], Tensor):
            ...             return (True, inputs[0])
            ...
            >>> addn = AddN()
            >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float32)
            >>> output = addn.check_elim((input_x,))
            >>> print(output)
            (True, Tensor(shape=[3], dtype=Float32, value= [ 1.00000000e+00,  2.00000000e+00,  3.00000000e+00]))
        """
        return (False, None)

    def __call__(self, *args):
        """Execute the primitive, short-circuiting when check_elim reports a precomputed result."""
        should_elim, output = self.check_elim(*args)
        if should_elim:
            return output
        return _run_op(self, self.name, args)

    def __getstate__(self):
        """Pickle support: the Python-side state is the instance __dict__."""
        return self.__dict__

    def __setstate__(self, d):
        """Pickle support: restore the Python-side state into __dict__."""
        self.__dict__.update(d)

    def __deepcopy__(self, memo):
        """Create a fresh instance by re-running __init__ with the recorded init arguments (memo unused)."""
        return type(self)(**self.init_attrs)

    def __repr__(self):
        """Return `Prim[name]<k=v, ...>`, omitting attributes listed in `_repr_ignore_list`."""
        attr = ', '.join([f'{k}={self.attrs.get(k)}' for k in self.attrs if k not in Primitive._repr_ignore_list])
        info_str = f'Prim[{self.name}]'
        if attr:
            info_str += f'<{attr}>'
        return info_str

    def init_prim_io_names(self, inputs, outputs):
        """
        Initialize the name of inputs and outputs of Tensor or attributes.

        Args:
            inputs (list[str]): list of inputs names.
            outputs (list[str]): list of outputs names.
        Examples:
            >>> from mindspore import ops
            >>> a = ops.Add()
            >>> a.init_prim_io_names(["x","y"],["sum"])
            >>> print(a.input_names)
            ['x', 'y']
            >>> print(a.output_names)
            ['sum']
        """
        # for checking para names with kernel implementation
        self.add_prim_attr("input_names", inputs)
        # for checking output number with kernel implementation
        self.add_prim_attr("output_names", outputs)

    @property
    def update_parameter(self):
        """Return whether the primitive will update the value of parameter."""
        return self._update_parameter

    def recompute(self, mode=True):
        """
        Set the primitive recomputed. If a primitive set recomputed feeds into some backward nodes
        for computing gradient, rather than storing the intermediate activation computed in forward
        pass, we will recompute it in backward pass.

        Note:

            - If the computation involves something like randomization or global variable, the equivalence
              is not guaranteed currently.
            - Not supported in pynative mode

        Args:
            mode (bool): Specifies whether the primitive is recomputed. Default: ``True`` .

        Raises:
            TypeError: If called in pynative mode, or if `mode` is not a bool.

        Examples:
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import Tensor, ops, nn
            >>> class NetRecompute(nn.Cell):
            ...     def __init__(self):
            ...         super(NetRecompute,self).__init__()
            ...         self.relu = ops.ReLU().recompute()
            ...         self.sqrt = ops.Sqrt()
            ...     def construct(self, x):
            ...         out = self.relu(x)
            ...         return self.sqrt(out)
            ...
            >>> class GradNet(nn.Cell):
            ...     def __init__(self, network):
            ...         super(GradNet,self).__init__()
            ...         self.network = network
            ...         self.grad = ops.GradOperation()
            ...     def construct(self, x):
            ...         g_out = self.grad(self.network)(x)
            ...         return g_out
            ...
            >>> x = Tensor(np.array([-1,1]).astype(np.float32))
            >>> net = NetRecompute()
            >>> grad = GradNet(net)
            >>> a = grad(x)
            >>> print(a)
            [0. 0.5]
        """
        if context.get_context("mode") == context.PYNATIVE_MODE:
            raise TypeError("Recompute is not supported in pynative mode currently.")
        Validator.check_bool(mode)
        self.add_prim_attr("recompute", mode)
        return self

    def place(self, role, rank_id):
        """
        Set the label for this primitive.
        This label tells MindSpore compiler on which process this operator should be launched.
        And each process's identical label consists of input 'role' and 'rank_id'.
        So by setting different operators with different labels,
        which will be launched on different processes, users can launch a distributed training job.

        Note:
            - This method is effective only after
              "mindspore.communication.init()" is called for dynamic cluster building.

        Args:
            role (str): The role of the process on which this operator will be launched.
                Only 'MS_WORKER' is supported for now.
            rank_id (int): The rank id of the process on which this operator will be launched.
                The rank_id is unique in processes with the same role.

        Raises:
            RuntimeError: In pynative mode, in parameter-server mode, or when combined with
                auto/data/hybrid parallel modes.

        Examples:
            >>> from mindspore import context
            >>> from mindspore import ops
            >>> context.set_context(mode=context.GRAPH_MODE)
            >>> matmul = ops.MatMul()
            >>> matmul.place('MS_WORKER', 0)
        """
        # The scheduler process does not run operators, so labels are irrelevant there.
        if _is_role_sched():
            return

        Validator.check_non_negative_int(rank_id, "rank_id", "Primitive.place")
        Validator.check_string(role, "MS_WORKER", "role", "Primitive.place")

        if context.get_context("mode") == context.PYNATIVE_MODE:
            raise RuntimeError("You are calling Primitive.place in pynative mode."
                               "It's only supported in graph mode. Please switch to graph mode.")

        # Get the execution context and check whether calling of this 'place' method is valid.
        # This is because placing operators to arbitrary processes while other distributed training mode
        # is enabled is very unpredictable and may cause fatal error.
        # Some of these cases are under development and others should not be supported.
        if _is_ps_mode():
            raise RuntimeError(
                "You are calling Primitive.place mixed with Parameter Server training. "
                "This case is not supported yet. "
                "Please call Primitive.place without Parameter Server training.")
        if _is_in_auto_parallel_mode() or _is_in_data_parallel_mode() or _is_in_hybrid_parallel_mode():
            raise RuntimeError(
                "You are calling Primitive.place mixed with other parallel features: "
                "'auto_parallel', 'data_parallel' and 'hybrid_parallel'. "
                "This case is still under development and not supported yet. "
                "Please call Primitive.place without these features.")
        self.add_prim_attr("ms_role", role)
        self.add_prim_attr("rank_id", rank_id)


class PrimitiveWithCheck(Primitive):
    """
    PrimitiveWithCheck is the base class of primitives in python, which defines functions to check the input arguments
    of operators, but uses the infer method registered in c++ source codes.

    There are three methods can be overridden to define the check logic of the primitive: __check__(), check_shape(),
    check_dtype(). If __check__() is defined in primitive, the __check__() has the highest priority to be called.
    If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the check logic of
    the shape and type. Method infer_value() can also be defined (such as PrimitiveWithInfer) for constant propagation.

    More on how to customize a Op, please refer to `Custom Operators
    <https://www.mindspore.cn/tutorials/experts/en/master/operation/op_custom.html>`_.

    Args:
        name (str): Name of the current Primitive.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import _checkparam as Validator
        >>> from mindspore import dtype as mstype
        >>> from mindspore.ops import prim_attr_register, PrimitiveWithCheck
        >>> # init a Primitive class with check
        >>> class Flatten(PrimitiveWithCheck):
        ...     @prim_attr_register
        ...     def __init__(self):
        ...         pass
        ...     def check_shape(self, input_x):
        ...         Validator.check_int(len(input_x), 1, Validator.GE, 'input_x rank', self.name)
        ...
        ...     def check_dtype(self, input_x):
        ...         Validator.check_subclass("input_x", input_x, mstype.tensor_type, self.name)
        ...
        >>> # init a Primitive obj
        >>> flatten = Flatten()
    """

    def __init__(self, name):
        Primitive.__init__(self, name)
        self.set_prim_type(prim_type.py_infer_check)

    def __check__(self, *args):
        """
        Checking the input shape and the input type of ops is valid.

        Each arg is a dict with at least 'dtype' and 'shape' keys. dtypes are always checked;
        shapes are checked only when every shape is fully known (not None and containing
        neither -1 nor -2, which mark dynamic dimensions/rank).
        """
        check_dtype_fn = getattr(self, 'check_dtype')
        check_dtype_fn(*(x['dtype'] for x in args))

        is_shape_known = True
        for x in args:
            shape = x['shape']
            if shape is None or -1 in shape or -2 in shape:
                is_shape_known = False
                break
        if is_shape_known:
            check_shape_fn = getattr(self, 'check_shape')
            check_shape_fn(*(x['shape'] for x in args))

    def _clone(self):
        """
        Deeply clones the primitive object.

        Calls the __init__() method with the same arguments. This method is called in parser if the
        flag self.__setattr_flag__ is True.
        """
        cloned_prim = Primitive._clone(self)
        return cloned_prim

    def check_shape(self, *args):
        """
        Check shapes of input args.

        Note:
            The shape of scalar is an empty tuple.

        Args:
            args (tuple(int)): shapes of input tensors.

        Return:
            None.
        """
        return None

    def check_dtype(self, *args):
        """
        Check data types of input args.

        Args:
            args (:class:`mindspore.dtype`): data type of inputs.

        Return:
            None.
        """
        return None


class PrimitiveWithInfer(Primitive):
    """
    PrimitiveWithInfer is the base class of primitives in python and defines functions for tracking inference
    in python.

    There are four method can be overridden to define the infer logic of the primitive: __infer__(), infer_shape(),
    infer_dtype(), and infer_value(). If __infer__() is defined in primitive, the __infer__() has the highest priority
    to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined to describe the infer
    logic of the shape and type. The infer_value() is used for constant propagation.

    More on how to customize a Op, please refer to `Custom Operators
    <https://www.mindspore.cn/tutorials/experts/en/master/operation/op_custom.html>`_.

    Args:
        name (str): Name of the current Primitive.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore.ops import prim_attr_register, PrimitiveWithInfer
        >>> # init a Primitive class with infer
        >>> class Add(PrimitiveWithInfer):
        ...     @prim_attr_register
        ...     def __init__(self):
        ...         pass
        ...
        ...     def infer_shape(self, x, y):
        ...         return x # output shape same as first input 'x'
        ...
        ...     def infer_dtype(self, x, y):
        ...         return x # output type same as first input 'x'
        ...
        >>> # init a Primitive obj
        >>> add = Add()
    """

    def __init__(self, name):
        Primitive.__init__(self, name)
        self.set_prim_type(prim_type.py_infer_shape)

    def _clone(self):
        """
        Deeply clones the primitive object.

        Calls the __init__() method with the same arguments. This method is called in parser if the
        flag self.__setattr_flag__ is True.
        """
        cloned_prim = Primitive._clone(self)
        return cloned_prim

    def infer_shape(self, *args):
        """
        Infer output shape based on input shape.

        Note:
            The shape of scalar is an empty tuple.

        Args:
            args (tuple(int)): shapes of input tensors.

        Return:
            `tuple(int)`, shapes of output tensors.
        """
        return None

    def infer_dtype(self, *args):
        """
        Infer output dtype based on input dtype.

        Args:
            args (:class:`mindspore.dtype`): data type of inputs.

        Return:
            :class:`mindspore.dtype`, data type of outputs.
        """
        return None

    def infer_value(self, *args):
        """
        Infer output value based on input value at compile time.

        Args:
            args (Any): value of inputs.

        Return:
            Value of outputs. Return `None`, the value can not be inferred at compile time in this case.
        """
        return None

    def __infer__(self, *args):
        """
        Infer shape, type, and value at the same time by using dictionary as arguments.

        Each arg is a dict with 'dtype', 'shape' and 'value' keys; for every track the matching
        infer_<track> method is called with that field from all args.
        """
        tracks = ['dtype', 'shape', 'value']
        out = {}
        for track in tracks:
            fn = getattr(self, 'infer_' + track)
            # fn may return None
            out[track] = fn(*(x[track] for x in args))

        return out


def prim_attr_register(fn):
    """
    Primitive attributes register.

    Register the decorator of the built-in operator primitive '__init__'.
    The function will add all the parameters of '__init__' as operator attributes ,
    and init primitive name.

    Args:
        fn (function): __init__ function of primitive.

    Returns:
        function, original function.

    Examples:
        >>> from mindspore.ops import prim_attr_register, PrimitiveWithCheck
        >>> class MatMul(PrimitiveWithCheck):
        ...     @prim_attr_register
        ...     def __init__(self, transpose_a=False, transpose_b=False):
        ...         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
        ...
        >>> # init a Primitive obj
        >>> matmul = MatMul()
    """

    @functools.wraps(fn)
    def deco(self, *args, **kwargs):
        # A class may register under a different operator name via `substitute_name`.
        class_name = self.__class__.__name__
        if hasattr(self.__class__, "substitute_name"):
            class_name = self.__class__.substitute_name
        if isinstance(self, PrimitiveWithInfer):
            PrimitiveWithInfer.__init__(self, class_name)
        elif isinstance(self, PrimitiveWithCheck):
            PrimitiveWithCheck.__init__(self, class_name)
        else:
            Primitive.__init__(self, class_name)
        # Bind the call arguments (with defaults filled in) so every __init__ parameter becomes an attribute.
        bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments
        del arguments['self']
        # init_attrs then holds exactly the __init__ arguments, so __deepcopy__ can replay the call.
        del self.init_attrs['name']
        for name in arguments:
            value = arguments[name]
            self.add_prim_attr(name, value)
            self.init_attrs[name] = value
        fn(self, *args, **kwargs)

    # Exposed so Primitive._clone can recover the original __init__ parameter names.
    deco.decorated_func = fn
    return deco


def prim_arg_register(fn):
    """
    Primitive attributes register.

    Register the decorator of the built-in operator primitive '__init__'.
    The function will add all the parameters of '__init__' as operator attributes ,
    and init primitive name.

    Unlike `prim_attr_register`, the arguments are stored only on the Python side
    (via `Primitive._set_prim_arg`) and are not pushed to the C++ attribute map.

    Args:
        fn (function): __init__ function of primitive.

    Returns:
        function, original function.

    Examples:
        >>> from mindspore.ops import prim_arg_register, PrimitiveWithCheck
        >>> class MatMul(PrimitiveWithCheck):
        ...     @prim_arg_register
        ...     def __init__(self, transpose_a=False, transpose_b=False):
        ...         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
        ...
        >>> # init a Primitive obj
        >>> matmul = MatMul()
    """

    @functools.wraps(fn)
    def deco(self, *args, **kwargs):
        # A class may register under a different operator name via `substitute_name`.
        class_name = self.__class__.__name__
        if hasattr(self.__class__, "substitute_name"):
            class_name = self.__class__.substitute_name
        if isinstance(self, PrimitiveWithInfer):
            PrimitiveWithInfer.__init__(self, class_name)
        elif isinstance(self, PrimitiveWithCheck):
            PrimitiveWithCheck.__init__(self, class_name)
        else:
            # Use class_name (not __class__.__name__) so substitute_name is honoured here too,
            # consistent with the branches above and with prim_attr_register.
            Primitive.__init__(self, class_name)
        bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments
        del arguments['self']
        del self.init_attrs['name']
        for name in arguments:
            value = arguments[name]
            self._set_prim_arg(name, value)
            self.init_attrs[name] = value
        fn(self, *args, **kwargs)

    # Exposed so Primitive._clone can recover the original __init__ parameter names.
    deco.decorated_func = fn
    return deco


def _check_contains_variable(item_dtype, item_value):
    """
    Check whether the item is or contains variable.

    An item counts as a variable when its value is None while its dtype is known. Lists and
    tuples are checked element-wise against `item_dtype[i]`; for dicts (unless the dtype is a
    `typing.Keyword`) both keys and values are checked against `item_dtype[i]`.
    """
    if isinstance(item_value, (list, tuple)):
        for i, element in enumerate(item_value):
            if _check_contains_variable(item_dtype[i], element):
                return True
    elif isinstance(item_value, dict):
        if isinstance(item_dtype, typing.Keyword):
            return item_value is None
        # Materialize keys/values once instead of rebuilding the lists on every iteration.
        keys = list(item_value.keys())
        values = list(item_value.values())
        for i, key in enumerate(keys):
            if _check_contains_variable(item_dtype[i], key):
                return True
        for i, value in enumerate(values):
            if _check_contains_variable(item_dtype[i], value):
                return True
    return item_dtype is not None and item_value is None


def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=True):
    """Used to calculate constant in graph compiling process and improve compile performance in GRAPH_MODE.

    Args:
        fn (function): A `fn` use as the infer_value of the output operator. Default: ``None`` .
        get_instance (bool): If ``True`` , return the instance of operator,
            otherwise return the operator class. Default: ``True`` .
        name (str): Defines the operator name. If `name` is ``None`` , use the function name as op name.
            Default: ``None`` .
        reuse_result (bool): If ``True`` , the operator will be executed once and reuse the result next time,
            otherwise the operator will always be executed. Default: ``True`` .
        check (bool): If ``True`` , the parameters will be checked
            and the warning message will be raised if the parameter is not const value. Default: ``True`` .

    Examples:

        >>> import mindspore as ms
        >>> # define a constant calculate function with for loop inside and use constexpr to accelerate the compile
        >>> # process.
        >>> @ms.constexpr
        ... def for_loop_calculate(range_num):
        ...     out = 0
        ...     for i in range(range_num):
        ...         if i %2 == 0 and i % 7 != 0:
        ...             out = out + i
        ...     return out // range_num
        ...
        >>> # construct a net and run with GRAPH_MODE.
        >>> @ms.jit
        ... def my_func(x):
        ...     new_shape = for_loop_calculate(100000)
        ...     return ms.ops.broadcast_to(x, (new_shape, ))
        ...
        >>> out = my_func(ms.Tensor([1]))
        >>> print(out.shape)
        (21428,)
    """

    def decorator(fn):
        """Decorator for ProxyOp."""

        class ProxyOp(PrimitiveWithInfer):
            """
            ProxyOp is a temporary operator used to execute the constexpr function.
            """

            def __init__(self):
                op_name = name if name else fn.__name__
                super(ProxyOp, self).__init__(op_name)
                self.set_const_prim(True)
                self.fn = fn
                self.add_prim_attr('constexpr_prim', True)
                if not reuse_result:
                    self.add_prim_attr('forbid_reuse_result', True)

            def __infer__(self, *args):
                # Evaluate fn on the compile-time values; non-constant inputs only trigger a warning.
                value_args = []
                for item in args:
                    item_value = item["value"]
                    if _check_contains_variable(item["dtype"], item_value) and check:
                        logger.warning("The \"" + self.name + "\" is a constexpr function." \
                                       " The input arguments must be all constant value.")
                    value_args.append(item_value)
                return {'dtype': None, 'shape': None, 'value': fn(*value_args)}

            def __call__(self, *args, **kwargs):
                # Outside graph compilation the wrapped function is simply called directly.
                return fn(*args, **kwargs)

        if get_instance:
            return ProxyOp()
        return ProxyOp

    # Support both @constexpr and @constexpr(...) usage.
    if fn is not None:
        return decorator(fn)
    return decorator


def _primexpr(fn=None, get_instance=True, name=None, reuse_result=True):
    """
    _primexpr is similar as constexpr except that when the input to the function decorated by _primexpr contains
    variable, the function will be compiled as graph.

    _primexpr is only for internal use.

    Args:
        fn (function): A `fn` use as the infer_value of the output operator. Default: ``None`` .
        get_instance (bool): If ``True`` , return the instance of operator,
            otherwise return the operator class. Default: ``True`` .
        name (str): Defines the operator name. If `name` is ``None`` , use the function name as op name.
            Default: ``None`` .
        reuse_result (bool): If ``True`` , the operator will be executed once and reuse the result next time,
            otherwise the operator will always be executed. Default: ``True`` .
    """

    def deco(fn):
        """Decorator for CompileOp."""

        class CompileOp(PrimitiveWithInfer):
            """
            CompileOp is a temporary operator used to execute the constexpr function.
            """

            def __init__(self):
                op_name = name if name else fn.__name__
                PrimitiveWithInfer.__init__(self, op_name)
                self.set_const_prim(True)
                self.fn = fn
                self.add_prim_attr('constexpr_prim', True)
                if not reuse_result:
                    self.add_prim_attr('forbid_reuse_result', True)

            def __infer__(self, *args):
                value_args = []
                for item in args:
                    if _check_contains_variable(item["dtype"], item["value"]):
                        # A variable input: return no value and hand back fn under the 'fn' key instead.
                        return {'dtype': None, 'shape': None, 'value': None, 'fn': (fn,)}
                    value_args.append(item["value"])
                return {'dtype': None, 'shape': None, 'value': fn(*value_args)}

            def __call__(self, *args, **kwargs):
                return fn(*args, **kwargs)

        if get_instance:
            return CompileOp()
        return CompileOp

    # Support both @_primexpr and @_primexpr(...) usage.
    if fn is not None:
        return deco(fn)
    return deco


class _RunOpHook:
    """
    Hook for run op.

    Used as a context manager: while active, `_run_op` calls `hook(obj, args)` instead of the
    pynative executor. Nested hooks restore the previous one on exit.
    """

    current = None

    def __init__(self, hook):
        self.hook = hook
        self.old = _RunOpHook.current

    def __enter__(self):
        _RunOpHook.current = self
        return self

    def __exit__(self, *err):
        _RunOpHook.current = self.old


def _run_op(obj, op_name, args):
    """Single op execution function supported by ge in PyNative mode."""
    if not _RunOpHook.current:
        stub = _pynative_executor.run_op_async(obj, op_name, args)
        return _convert_stub(stub)
    return _RunOpHook.current.hook(obj, args)


@_wrap_func
def _run_op_sync(obj, op_name, args):
    """Single op execution function in synchronous mode."""
    output = _pynative_executor.real_run_op(obj, op_name, args)
    return output


class _PrimitiveC(Primitive):
    """Primitive built from a name and a mapping of attributes, each added via add_prim_attr."""

    def __init__(self, name, attrs):
        super().__init__(name)
        for key, value in attrs.items():
            super().add_prim_attr(key, value)


def _get_primitivec(name, attrs):
    """Create a `_PrimitiveC` with the given name and attributes."""
    return _PrimitiveC(name, attrs)


def _create_primitive_function_obj():
    """Create a new `PrimitiveFunction_` object."""
    return PrimitiveFunction_()