# Copyright 2020-2024 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

16"""primitive"""
17import functools
18import inspect
19import copy
20import numpy as np
21from mindspore.common.api import _wrap_func
22from mindspore.log import _LogActionOnce
23from mindspore import context, log as logger
24from mindspore.parallel._utils import _is_in_auto_parallel_mode, _is_in_data_parallel_mode, \
25    _is_in_hybrid_parallel_mode, SUPPORTED_TUPLE_IN_TUPLE_STRATEGY
26from mindspore.parallel._ps_context import _is_ps_mode, _is_role_sched
27from mindspore.parallel.shard import Layout
28from mindspore.common.api import _pynative_executor
29from mindspore.common._stub_tensor import _convert_stub
30from mindspore._c_expression import Primitive_, PrimitiveFunction_, prim_type, typing
31from mindspore import _checkparam as Validator
32from mindspore.ops import signature as sig
33
34
class Primitive(Primitive_):
    """
    Primitive is the base class of operator primitives in Python.

    Args:
        name (str): Name for the current Primitive.

    Examples:
        >>> from mindspore.ops import prim_attr_register, Primitive
        >>> add = Primitive('add')
        >>>
        >>> # or work with prim_attr_register:
        >>> # init a Primitive class with attr1 and attr2
        >>> class Add(Primitive):
        ...     @prim_attr_register
        ...     def __init__(self, attr1, attr2):
        ...         '''init for add'''
        ...     # check attr1 and attr2 or do some initializations
        ...     # init a Primitive obj with attr1=1 and attr2=2
        >>> add = Add(attr1=1, attr2=2)
    """
    _repr_ignore_list = ['input_names', 'output_names']

    def __init__(self, name):
        self.name = name
        self.attrs = {}
        self.init_attrs = {"name": name}
        self._update_parameter = False
        Primitive_.__init__(self, name)
        if hasattr(self.__class__, '__mindspore_signature__'):
            out = self._fill_signature(self.__class__.__mindspore_signature__)
            self.set_signatures(out)

    def add_prim_attr(self, name, value):
        """
        Add primitive attribute.

        Args:
            name (str): Attribute Name.
            value (Any): Attribute value.

        Examples:
            >>> from mindspore import ops
            >>> a = ops.Add()
            >>> a = a.add_prim_attr("attr",1)
            >>> out = a.attrs["attr"]
            >>> print(out)
            1
        """
        self.__dict__[name] = value
        self.attrs[name] = value
        self.add_attr(name, value)
        return self

    def _set_prim_arg(self, name, value):
        """
        Set primitive initialization arguments.

        Different from add_prim_attr, it is used internally to store Primitive
        initialization arguments in Python.
        """
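        # Note: unlike add_prim_attr, this does not call the C++-side add_attr,
        # so the value is recorded only on the Python side.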
        self.__dict__[name] = value
        self.attrs[name] = value
        return self

    def _set_prim_arg_with_handler(self, name, value, arg_handler):
        """
        Set primitive initialization arguments after processing them with arg_handler.
        """
        value = value if value is None else arg_handler(self.__class__.__name__, name, value)
        return self._set_prim_arg(name, value)

    def set_device(self, device_target):
        """
        Set the device on which the primitive will be executed.

        Args:
            device_target (str): The target device to run on; "Ascend", "GPU", and "CPU" are supported.

        Examples:
            >>> from mindspore import ops
            >>> a = ops.Add()
            >>> a = a.set_device("GPU")
            >>> print(a.primitive_target)
            GPU
        """
        return self.add_prim_attr("primitive_target", device_target)

    def _fill_signature(self, signatures):
        """fills signature."""
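        # Entries may be full sig.Signature objects, bare sig.sig_dtype values,
        # or tuples of make_sig arguments; all are normalized to sig.Signature.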
        signatures_new = []
        for signature in signatures:
            if isinstance(signature, sig.Signature):
                signatures_new.append(signature)
            elif isinstance(signature, sig.sig_dtype):
                signatures_new.append(sig.make_sig(dtype=signature))
            else:
                if len(signature) < 3:
                    raise ValueError(f"[Internal Error] The signature for one parameter must have at least "
                                     f"3 elements, but got {signature}")
                signatures_new.append(sig.make_sig(*signature))
        return tuple(signatures_new)

    def _clone(self):
        """
        Deeply clones the primitive object.

        Calls the __init__() method with the same arguments. This method is called in parser if the
        flag self.__setattr_flag__ is True.
        """
        cloned = copy.deepcopy(self)
        init_params = list()
        if hasattr(cloned.__init__, 'decorated_func'):
            init_params = inspect.getfullargspec(cloned.__init__.decorated_func).args[1:]
        init_args = self.init_attrs
        for name in init_params:
            value = self.attrs[name]
            init_args[name] = value
        # __init__ should be called to construct cpp object.
        cloned.__init__(**init_args)
        for name in self.attrs:
            value = self.attrs[name]
            cloned.add_prim_attr(name, value)
        if hasattr(self, 'instance_name'):
            cloned.set_prim_instance_name(self.instance_name)
        return cloned

    def _check_shard_strategy(self, strategy, log_info):
        """Check whether the shard strategy is valid."""
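        # For example, ((2, 1), (2, 1)) is a pure tuple strategy, while
        # (layout("dp", "mp"), layout("mp", "None")) is a pure Layout strategy;
        # mixing the two kinds within one strategy is rejected below.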
        is_layout = []
        if not isinstance(strategy, tuple):
            raise TypeError(f'{log_info} must be tuple type, but got:{type(strategy)}')
        for in_ele in strategy:
            if not isinstance(in_ele, tuple) and not isinstance(in_ele, Layout):
                raise TypeError(f'The element of strategy must be tuple/Layout type, but got:{type(in_ele)}')
            if isinstance(in_ele, tuple):
                for in_value in in_ele:
                    if not isinstance(in_value, int) and self.name not in SUPPORTED_TUPLE_IN_TUPLE_STRATEGY:
                        raise TypeError(f'The {log_info}: {strategy} of {self.name} is not valid,'
                                        f' the value of strategy must be int type, but got:{type(in_value)}')
                is_layout.append(False)
                continue
            is_layout.append(True)
        if is_layout:
            np_is_layout = np.array(is_layout)
            if not (np_is_layout == np_is_layout[0]).all():
                raise TypeError(f'{log_info} item must be all tuple type or all Layout type.')
        return np.array(is_layout)

    def _extract_layout_value(self, layout, log_info):
        """Extract parallel layout value"""
        layout_value = None
        if layout is not None:
            if not isinstance(layout, tuple):
                raise TypeError(f'{log_info} must be tuple type, but got:{type(layout)}')
            layout_value = ()
            for in_ele in layout:
                if not isinstance(in_ele, Layout):
                    raise TypeError(f"The {log_info} item should be an object of class Layout.")
                layout_value += (in_ele.to_dict(),)
        return layout_value

    def _check_shard_strategy_in_out_match(self, in_strategy, out_strategy):
        """Check that shard in_strategy and out_strategy match."""
        if in_strategy is None and out_strategy is not None:
            raise ValueError(f'The out_strategy of {self.name} is {out_strategy}, so in_strategy must also be set,'
                             f' but got None')
        if not _is_in_auto_parallel_mode():
            mode = context.get_auto_parallel_context("parallel_mode")
            if in_strategy is not None:
                logger.warning(f"The in_strategy/in_layout of the operator in your network "
                               f"will not take effect in {mode} mode. "
                               f"This means the shard function called in the network is ignored. \n"
                               f"If you want to enable it, please use semi auto or auto parallel mode by "
                               f"context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL) "
                               f"or context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL)")
            if out_strategy is not None:
                logger.warning(f"The out_strategy/out_layout of the operator in your network "
                               f"will not take effect in {mode} mode."
                               f" This means the shard function called in the network is ignored. \n"
                               f"If you want to enable it, please use semi auto or auto parallel mode by "
                               f"context.set_auto_parallel_context(parallel_mode=ParallelMode.SEMI_AUTO_PARALLEL) "
                               f"or context.set_auto_parallel_context(parallel_mode=ParallelMode.AUTO_PARALLEL)")

    def del_prim_attr(self, name):
        """
        Delete primitive attribute.

        Args:
            name (str): Attribute Name.

        Examples:
            >>> from mindspore import ops
            >>> a = ops.Add()
            >>> a = a.add_prim_attr("attr",1)
            >>> a = a.del_prim_attr("attr")
            >>> print(a.attrs)
            {}
        """
        if name in self.__dict__ and name in self.attrs:
            del self.__dict__[name]
            del self.attrs[name]
            self.del_attr(name)
        return self

    def set_stage(self, stage):
        """
        Add stage id to primitive attribute.

        Note:
            It is valid only in semi auto parallel.
            In other parallel modes, please set it to be 0.

        Args:
            stage (int): The stage id for the current operation.

        Examples:
            >>> from mindspore import ops
            >>> add = ops.Add()
            >>> print(add.set_stage(0))
            Prim[Add]<stage=0>
        """
        self.add_prim_attr("stage", stage)
        return self

    @_LogActionOnce(logger=logger, key='Primitive')
    def shard(self, in_strategy=None, out_strategy=None):
        """
        Add strategies to primitive attribute.

        Note:
            It is valid only in semi auto parallel or auto parallel mode.
            In other parallel modes, strategies set here will be ignored.

        Args:
            in_strategy (tuple): Describe the split strategy of operator input. Default: ``None`` .
            out_strategy (tuple): Describe the split strategy of operator output, it is only for certain operators,
                                  such as MatMul. Default: ``None`` .

        Examples:
            >>> from mindspore import ops
            >>> add = ops.Add()
            >>> print(add.shard(((1, 1), (1, 1))))
            Prim[Add]<in_strategy=((1, 1), (1, 1)), out_strategy=None>
            >>> # using layout
            >>> from mindspore import Layout
            >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
            >>> layout_tuple = (layout("dp", "sp"), layout("sp", "mp"))
            >>> from mindspore import ops
            >>> matmul = ops.MatMul()
            >>> print(matmul.shard(layout_tuple))
            Prim[MatMul]<in_layout=({'device_matrix': (2, 2, 2), 'tensor_map': (2, 1)},
            {'device_matrix': (2, 2, 2), 'tensor_map': (1, 0)})>
            >>> # using layout with None
            >>> from mindspore import Layout
            >>> layout = Layout((2, 2, 2), ("dp", "sp", "mp"))
            >>> layout_tuple = (layout("dp", "sp"), layout("sp", "None")) # "None" means the axis would not be split
            >>> from mindspore import ops
            >>> matmul = ops.MatMul()
            >>> print(matmul.shard(layout_tuple))
            Prim[MatMul]<in_layout=({'device_matrix': (2, 2, 2), 'tensor_map': (2, 1)},
            {'device_matrix': (2, 2, 2), 'tensor_map': (1, -1)})>
        """
        in_is_layout = None
        out_is_layout = None
        if in_strategy is not None:
            in_is_layout = self._check_shard_strategy(in_strategy, "in_strategy")

        if out_strategy is not None:
            out_is_layout = self._check_shard_strategy(out_strategy, "out_strategy")
        self._check_shard_strategy_in_out_match(in_strategy, out_strategy)
        if in_is_layout is not None and out_is_layout is not None and in_is_layout[0] != out_is_layout[0]:
            raise ValueError(f'The in_strategy type must match the out_strategy type: '
                             f'one using tuple(tuple) and the other using tuple(Layout) is not allowed.')
        in_layout_value = None
        out_layout_value = None
        if in_is_layout is not None and in_is_layout[0]:
            in_layout_value = self._extract_layout_value(in_strategy, "in_strategy")
        if out_is_layout is not None and out_is_layout[0]:
            out_layout_value = self._extract_layout_value(out_strategy, "out_strategy")

        if in_is_layout is not None and not in_is_layout[0]:
            self.add_prim_attr("in_strategy", in_strategy)
        if out_is_layout is not None and not out_is_layout[0]:
            self.add_prim_attr("out_strategy", out_strategy)
        if in_layout_value:
            self.add_prim_attr("in_layout", in_layout_value)
        if out_layout_value:
            self.add_prim_attr("out_layout", out_layout_value)
        return self

    def set_prim_instance_name(self, instance_name):
        """
        Set instance name to primitive operator.

        Note:
            It will be called by default when user defines primitive operator.

        Args:
            instance_name (str): Instance name of primitive operator set by user.

        Examples:
            >>> from mindspore import ops
            >>> a = ops.Add()
            >>> a = a.set_prim_instance_name("add")
            >>> print(a.instance_name)
            add
        """
        self.set_instance_name(instance_name)
        self.instance_name = instance_name
        return self

    def __getattr__(self, item):
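        # Lookup order: the C++-side attribute dict first, then the Python-side
        # attrs; 'infer_dynamic_shape' is special-cased to None.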
        if item == 'infer_dynamic_shape':
            return None
        if item in super().get_attr_dict():
            return super().get_attr_dict()[item]
        if item in self.attrs:
            return self.attrs[item]
        err_msg = "'{prim}' object has no attribute '{attr}'".format(prim=self.name, attr=item)
        raise AttributeError(err_msg)

    def check_elim(self, *args):
        """
        Check whether the primitive can be eliminated. Subclasses that need this feature should override this method.

        Args:
            args (Primitive args): Same as arguments of current Primitive.

        Returns:
            A tuple consisting of two elements.
            The first element indicates whether the primitive can be calculated at compile time;
            the second element is the calculated result.

        Examples:
            >>> import numpy as np
            >>> import mindspore
            >>> from mindspore import Tensor
            >>> from mindspore.ops import prim_attr_register, Primitive
            >>> class AddN(Primitive):
            ...     @prim_attr_register
            ...     def __init__(self):
            ...         self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
            ...     def check_elim(self, inputs):
            ...         if len(inputs) != 1:
            ...             return (False, None)
            ...         if isinstance(inputs[0], Tensor):
            ...             return (True, inputs[0])
            ...
            >>> addn = AddN()
            >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float32)
            >>> output = addn.check_elim((input_x,))
            >>> print(output)
            (True, Tensor(shape=[3], dtype=Float32, value= [ 1.00000000e+00,  2.00000000e+00,  3.00000000e+00]))
        """
        return (False, None)

    def __call__(self, *args):
        should_elim, output = self.check_elim(*args)
        if should_elim:
            return output
        return _run_op(self, self.name, args)

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, d):
        self.__dict__.update(d)

    def __deepcopy__(self, memo):
        return type(self)(**self.init_attrs)

    def __repr__(self):
        attr = ', '.join([f'{k}={self.attrs.get(k)}' for k in self.attrs if k not in Primitive._repr_ignore_list])
        info_str = f'Prim[{self.name}]'
        if attr:
            info_str += f'<{attr}>'
        return info_str

    def init_prim_io_names(self, inputs, outputs):
        """
        Initialize the names of the inputs and outputs of the primitive.

        Args:
            inputs (list[str]): list of input names.
            outputs (list[str]): list of output names.

        Examples:
            >>> from mindspore import ops
            >>> a = ops.Add()
            >>> a.init_prim_io_names(["x", "y"], ["sum"])
            >>> print(a.input_names)
            ['x', 'y']
            >>> print(a.output_names)
            ['sum']
        """
        # for checking para names with kernel implementation
        self.add_prim_attr("input_names", inputs)
        # for checking output number with kernel implementation
        self.add_prim_attr("output_names", outputs)

    @property
    def update_parameter(self):
        """Return whether the primitive will update the value of parameter."""
        return self._update_parameter

    def recompute(self, mode=True):
        """
        Set whether the primitive is recomputed. If a primitive set as recomputed feeds into some backward
        nodes for computing gradient, rather than storing the intermediate activation computed in the forward
        pass, we will recompute it in the backward pass.

        Note:
            - If the computation involves something like randomization or global variable, the equivalence
              is not guaranteed currently.
            - Not supported in PyNative mode.

        Args:
            mode (bool): Specifies whether the primitive is recomputed. Default: ``True`` .

        Examples:
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import Tensor, ops, nn
            >>> class NetRecompute(nn.Cell):
            ...     def __init__(self):
            ...         super(NetRecompute,self).__init__()
            ...         self.relu = ops.ReLU().recompute()
            ...         self.sqrt = ops.Sqrt()
            ...     def construct(self, x):
            ...         out = self.relu(x)
            ...         return self.sqrt(out)
            ...
            >>> class GradNet(nn.Cell):
            ...     def __init__(self, network):
            ...         super(GradNet,self).__init__()
            ...         self.network = network
            ...         self.grad = ops.GradOperation()
            ...     def construct(self, x):
            ...         g_out = self.grad(self.network)(x)
            ...         return g_out
            ...
            >>> x = Tensor(np.array([-1,1]).astype(np.float32))
            >>> net = NetRecompute()
            >>> grad = GradNet(net)
            >>> a = grad(x)
            >>> print(a)
            [0. 0.5]
        """
        if context.get_context("mode") == context.PYNATIVE_MODE:
            raise TypeError("Recompute is not supported in PyNative mode currently.")
        Validator.check_bool(mode)
        self.add_prim_attr("recompute", mode)
        return self

    def place(self, role, rank_id):
        """
        Set the label for this primitive.
        This label tells the MindSpore compiler on which process this operator should be launched.
        Each process's label consists of the inputs 'role' and 'rank_id', so by assigning different labels
        to different operators, which will then be launched on different processes, users can launch a
        distributed training job.

        Note:
            - This method is effective only after
              "mindspore.communication.init()" is called for dynamic cluster building.

        Args:
            role (str): The role of the process on which this operator will be launched.
                        Only 'MS_WORKER' is supported for now.
            rank_id (int): The rank id of the process on which this operator will be launched.
                           The rank_id is unique in processes with the same role.

        Examples:
            >>> from mindspore import context
            >>> from mindspore import ops
            >>> context.set_context(mode=context.GRAPH_MODE)
            >>> matmul = ops.MatMul()
            >>> matmul.place('MS_WORKER', 0)
        """
        if _is_role_sched():
            return

        Validator.check_non_negative_int(rank_id, "rank_id", "Primitive.place")
        Validator.check_string(role, "MS_WORKER", "role", "Primitive.place")

        if context.get_context("mode") == context.PYNATIVE_MODE:
            raise RuntimeError("You are calling Primitive.place in pynative mode. "
                               "It's only supported in graph mode. Please switch to graph mode.")

        # Get the execution context and check whether calling of this 'place' method is valid.
        # This is because placing operators to arbitrary processes while other distributed training mode
        # is enabled is very unpredictable and may cause fatal error.
        # Some of these cases are under development and others should not be supported.
        if _is_ps_mode():
            raise RuntimeError(
                "You are calling Primitive.place mixed with Parameter Server training. "
                "This case is not supported yet. "
                "Please call Primitive.place without Parameter Server training.")
        if _is_in_auto_parallel_mode() or _is_in_data_parallel_mode() or _is_in_hybrid_parallel_mode():
            raise RuntimeError(
                "You are calling Primitive.place mixed with other parallel features: "
                "'auto_parallel', 'data_parallel' and 'hybrid_parallel'. "
                "This case is still under development and not supported yet. "
                "Please call Primitive.place without these features.")
        self.add_prim_attr("ms_role", role)
        self.add_prim_attr("rank_id", rank_id)


class PrimitiveWithCheck(Primitive):
    """
    PrimitiveWithCheck is the base class of primitives in Python, which defines functions to check the input
    arguments of operators, but uses the infer method registered in the C++ source code.

    There are three methods that can be overridden to define the check logic of the primitive: __check__(),
    check_shape(), and check_dtype(). If __check__() is defined in the primitive, it has the highest priority
    to be called. If __check__() is not defined, check_shape() and check_dtype() can be defined to describe the
    check logic of the shape and type. The infer_value() method can also be defined (as in PrimitiveWithInfer)
    for constant propagation.

    For more on how to customize an operator, please refer to `Custom Operators
    <https://www.mindspore.cn/tutorials/experts/en/master/operation/op_custom.html>`_.

    Args:
        name (str): Name of the current Primitive.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import dtype as mstype
        >>> from mindspore import _checkparam as Validator
        >>> from mindspore.ops import prim_attr_register, PrimitiveWithCheck
        >>> # init a Primitive class with check
        >>> class Flatten(PrimitiveWithCheck):
        ...     @prim_attr_register
        ...     def __init__(self):
        ...         pass
        ...     def check_shape(self, input_x):
        ...         Validator.check_int(len(input_x), 1, Validator.GE, 'input_x rank', self.name)
        ...
        ...     def check_dtype(self, input_x):
        ...         Validator.check_subclass("input_x", input_x, mstype.tensor_type, self.name)
        ...
        >>> # init a Primitive obj
        >>> add = Flatten()
    """

    def __init__(self, name):
        Primitive.__init__(self, name)
        self.set_prim_type(prim_type.py_infer_check)

    def __check__(self, *args):
        """Check that the input shapes and input types of the op are valid."""
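        # Each input is described by a dict with 'dtype' and 'shape' tracks; a
        # None shape, or one containing -1/-2 (dynamic dim/rank), is treated as
        # unknown, and the shape check is skipped in that case.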
        check_dtype_fn = getattr(self, 'check_dtype')
        check_dtype_fn(*(x['dtype'] for x in args))

        is_shape_known = True
        for x in args:
            shape = x['shape']
            if shape is None or -1 in shape or -2 in shape:
                is_shape_known = False
                break
        if is_shape_known:
            check_shape_fn = getattr(self, 'check_shape')
            check_shape_fn(*(x['shape'] for x in args))

    def _clone(self):
        """
        Deeply clones the primitive object.

        Calls the __init__() method with the same arguments. This method is called in parser if the
        flag self.__setattr_flag__ is True.
        """
        cloned_prim = Primitive._clone(self)
        return cloned_prim

    def check_shape(self, *args):
        """
        Check shapes of input args.

        Note:
            The shape of a scalar is an empty tuple.

        Args:
            args (tuple(int)): shapes of input tensors.

        Returns:
            None.
        """
        return None

    def check_dtype(self, *args):
        """
        Check data types of input args.

        Args:
            args (:class:`mindspore.dtype`): data types of inputs.

        Returns:
            None.
        """
        return None


class PrimitiveWithInfer(Primitive):
    """
    PrimitiveWithInfer is the base class of primitives in Python and defines functions for tracking inference
    in Python.

    There are four methods that can be overridden to define the infer logic of the primitive: __infer__(),
    infer_shape(), infer_dtype(), and infer_value(). If __infer__() is defined in the primitive, it has the
    highest priority to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be defined
    to describe the infer logic of the shape and type. The infer_value() is used for constant propagation.

    For more on how to customize an operator, please refer to `Custom Operators
    <https://www.mindspore.cn/tutorials/experts/en/master/operation/op_custom.html>`_.

    Args:
        name (str): Name of the current Primitive.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore.ops import prim_attr_register, PrimitiveWithInfer
        >>> # init a Primitive class with infer
        >>> class Add(PrimitiveWithInfer):
        ...     @prim_attr_register
        ...     def __init__(self):
        ...         pass
        ...
        ...     def infer_shape(self, x, y):
        ...         return x # output shape same as first input 'x'
        ...
        ...     def infer_dtype(self, x, y):
        ...         return x # output type same as first input 'x'
        ...
        >>> # init a Primitive obj
        >>> add = Add()
    """

    def __init__(self, name):
        Primitive.__init__(self, name)
        self.set_prim_type(prim_type.py_infer_shape)

    def _clone(self):
        """
        Deeply clones the primitive object.

        Calls the __init__() method with the same arguments. This method is called in parser if the
        flag self.__setattr_flag__ is True.
        """
        cloned_prim = Primitive._clone(self)
        return cloned_prim

    def infer_shape(self, *args):
        """
        Infer output shape based on input shape.

        Note:
            The shape of a scalar is an empty tuple.

        Args:
            args (tuple(int)): shapes of input tensors.

        Returns:
            `tuple(int)`, shapes of output tensors.
        """
        return None

    def infer_dtype(self, *args):
        """
        Infer output dtype based on input dtype.

        Args:
            args (:class:`mindspore.dtype`): data types of inputs.

        Returns:
            :class:`mindspore.dtype`, data type of outputs.
        """
        return None

    def infer_value(self, *args):
        """
        Infer output value based on input value at compile time.

        Args:
            args (Any): values of inputs.

        Returns:
            Value of outputs. Returns `None` if the value cannot be inferred at compile time.
        """
        return None

    def __infer__(self, *args):
        """Infer shape, type, and value at the same time by using dictionaries as arguments."""
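        # Each element of args is a dict carrying the 'dtype', 'shape' and
        # 'value' tracks of one input; a track that cannot be determined at
        # compile time arrives as None.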
        tracks = ['dtype', 'shape', 'value']
        out = {}
        for track in tracks:
            fn = getattr(self, 'infer_' + track)
            # fn may return None
            out[track] = fn(*(x[track] for x in args))

        return out


def prim_attr_register(fn):
    """
    Primitive attributes register.

    Register the decorator of the built-in operator primitive '__init__'.
    The function will add all the parameters of '__init__' as operator attributes,
    and initialize the primitive name.

    Args:
        fn (function): __init__ function of primitive.

    Returns:
        function, original function.

    Examples:
        >>> from mindspore.ops import prim_attr_register, PrimitiveWithCheck
        >>> class MatMul(PrimitiveWithCheck):
        ...     @prim_attr_register
        ...     def __init__(self, transpose_a=False, transpose_b=False):
        ...         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
        ...
        >>> # init a Primitive obj
        >>> matmul = MatMul()
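        >>> # The __init__ parameters are now recorded as attributes:
        >>> print(matmul.transpose_a)
        False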
    """

    @functools.wraps(fn)
    def deco(self, *args, **kwargs):
        class_name = self.__class__.__name__
        if hasattr(self.__class__, "substitute_name"):
            class_name = self.__class__.substitute_name
        if isinstance(self, PrimitiveWithInfer):
            PrimitiveWithInfer.__init__(self, class_name)
        elif isinstance(self, PrimitiveWithCheck):
            PrimitiveWithCheck.__init__(self, class_name)
        else:
            Primitive.__init__(self, class_name)
        bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments
        del arguments['self']
        del self.init_attrs['name']
        for name in arguments:
            value = arguments[name]
            self.add_prim_attr(name, value)
            self.init_attrs[name] = value
        fn(self, *args, **kwargs)

    deco.decorated_func = fn
    return deco


def prim_arg_register(fn):
    """
    Primitive initialization arguments register.

    Register the decorator of the built-in operator primitive '__init__'.
    The function will record all the parameters of '__init__' as primitive initialization arguments,
    and initialize the primitive name.

    Args:
        fn (function): __init__ function of primitive.

    Returns:
        function, original function.

    Examples:
        >>> from mindspore.ops import prim_arg_register, PrimitiveWithCheck
        >>> class MatMul(PrimitiveWithCheck):
        ...     @prim_arg_register
        ...     def __init__(self, transpose_a=False, transpose_b=False):
        ...         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
        ...
        >>> # init a Primitive obj
        >>> matmul = MatMul()
    """

    @functools.wraps(fn)
    def deco(self, *args, **kwargs):
        class_name = self.__class__.__name__
        if hasattr(self.__class__, "substitute_name"):
            class_name = self.__class__.substitute_name
        if isinstance(self, PrimitiveWithInfer):
            PrimitiveWithInfer.__init__(self, class_name)
        elif isinstance(self, PrimitiveWithCheck):
            PrimitiveWithCheck.__init__(self, class_name)
        else:
            Primitive.__init__(self, class_name)
        bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments
        del arguments['self']
        del self.init_attrs['name']
        for name in arguments:
            value = arguments[name]
            self._set_prim_arg(name, value)
            self.init_attrs[name] = value
        fn(self, *args, **kwargs)

    deco.decorated_func = fn
    return deco


def _check_contains_variable(item_dtype, item_value):
    """
    Check whether the item is or contains a variable.
    """
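    # For example, a tuple whose dtype track is known but whose value track is
    # (1, None) contains a variable element, so this returns True.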
    if isinstance(item_value, (list, tuple)):
        for i, element in enumerate(item_value):
            if _check_contains_variable(item_dtype[i], element):
                return True
    elif isinstance(item_value, dict):
        if isinstance(item_dtype, typing.Keyword):
            return item_value is None
        for i in range(len(item_value)):
            if _check_contains_variable(item_dtype[i], list(item_value.keys())[i]):
                return True
        for i in range(len(item_value)):
            if _check_contains_variable(item_dtype[i], list(item_value.values())[i]):
                return True
    return item_dtype is not None and item_value is None


def constexpr(fn=None, get_instance=True, name=None, reuse_result=True, check=True):
    """Used to calculate constants in the graph compiling process and improve compile performance in GRAPH_MODE.

    Args:
        fn (function): A `fn` used as the infer_value of the output operator. Default: ``None`` .
        get_instance (bool): If ``True`` , return the instance of operator,
                             otherwise return the operator class. Default: ``True`` .
        name (str): Defines the operator name. If `name` is ``None`` , use the function name as op name.
                             Default: ``None`` .
        reuse_result (bool): If ``True`` , the operator will be executed once and the result reused next time,
                             otherwise the operator will always be executed. Default: ``True`` .
        check (bool): If ``True`` , the parameters will be checked
            and a warning message will be raised if a parameter is not a const value. Default: ``True`` .

    Examples:
        >>> import mindspore as ms
        >>> # define a constant calculation function with a for loop inside and use constexpr to accelerate
        >>> # the compile process.
        >>> @ms.constexpr
        ... def for_loop_calculate(range_num):
        ...     out = 0
        ...     for i in range(range_num):
        ...         if i % 2 == 0 and i % 7 != 0:
        ...             out = out + i
        ...     return out // range_num
        ...
        >>> # construct a net and run with GRAPH_MODE.
        >>> @ms.jit
        ... def my_func(x):
        ...     new_shape = for_loop_calculate(100000)
        ...     return ms.ops.broadcast_to(x, (new_shape, ))
        ...
        >>> out = my_func(ms.Tensor([1]))
        >>> print(out.shape)
        (21428,)
    """

    def decorator(fn):
        """Decorator for ProxyOp."""

        class ProxyOp(PrimitiveWithInfer):
            """
            ProxyOp is a temporary operator used to execute the constexpr function.
            """

            def __init__(self):
                op_name = name if name else fn.__name__
                super(ProxyOp, self).__init__(op_name)
                self.set_const_prim(True)
                self.fn = fn
                self.add_prim_attr('constexpr_prim', True)
                if not reuse_result:
                    self.add_prim_attr('forbid_reuse_result', True)

            def __infer__(self, *args):
                value_args = []
                for item in args:
                    item_value = item["value"]
                    if _check_contains_variable(item["dtype"], item_value) and check:
                        logger.warning("The \"" + self.name + "\" is a constexpr function." \
                                                              " The input arguments must be all constant value.")
                    value_args.append(item_value)
                return {'dtype': None, 'shape': None, 'value': fn(*value_args)}

            def __call__(self, *args, **kwargs):
                return fn(*args, **kwargs)

        if get_instance:
            return ProxyOp()
        return ProxyOp

    if fn is not None:
        return decorator(fn)
    return decorator


def _primexpr(fn=None, get_instance=True, name=None, reuse_result=True):
    """
    _primexpr is similar to constexpr except that, when the input to the function decorated by _primexpr
    contains a variable, the function will be compiled as a graph.

    _primexpr is only for internal use.

    Args:
        fn (function): A `fn` used as the infer_value of the output operator. Default: ``None`` .
        get_instance (bool): If ``True`` , return the instance of operator,
                             otherwise return the operator class. Default: ``True`` .
        name (str): Defines the operator name. If `name` is ``None`` , use the function name as op name.
                             Default: ``None`` .
        reuse_result (bool): If ``True`` , the operator will be executed once and the result reused next time,
                             otherwise the operator will always be executed. Default: ``True`` .
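
    Examples:
        >>> # An illustrative sketch (internal API); calling the returned
        >>> # instance directly simply runs the wrapped Python function.
        >>> @_primexpr
        ... def add_one(x):
        ...     return x + 1
        >>> print(add_one(2))
        3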
    """

    def deco(fn):
        """Decorator for CompileOp."""

        class CompileOp(PrimitiveWithInfer):
            """
            CompileOp is a temporary operator used to execute the _primexpr function.
            """

            def __init__(self):
                op_name = name if name else fn.__name__
                PrimitiveWithInfer.__init__(self, op_name)
                self.set_const_prim(True)
                self.fn = fn
                self.add_prim_attr('constexpr_prim', True)
                if not reuse_result:
                    self.add_prim_attr('forbid_reuse_result', True)

            def __infer__(self, *args):
                value_args = []
                for item in args:
                    if _check_contains_variable(item["dtype"], item["value"]):
                        return {'dtype': None, 'shape': None, 'value': None, 'fn': (fn,)}
                    value_args.append(item["value"])
                return {'dtype': None, 'shape': None, 'value': fn(*value_args)}

            def __call__(self, *args, **kwargs):
                return fn(*args, **kwargs)

        if get_instance:
            return CompileOp()
        return CompileOp

    if fn is not None:
        return deco(fn)
    return deco


class _RunOpHook:
    """Hook for run op"""

    current = None

    def __init__(self, hook):
        self.hook = hook
        self.old = _RunOpHook.current

    def __enter__(self):
        _RunOpHook.current = self
        return self

    def __exit__(self, *err):
        _RunOpHook.current = self.old
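
# A hypothetical usage sketch (not part of the original file): installing a
# hook diverts the op execution in _run_op below for the extent of the
# with-block, e.g. to trace every primitive call; 'my_trace_hook' and
# 'some_primitive' are illustrative names. The hook receives the primitive
# and its args and must return the op result.
#
#   def my_trace_hook(prim, args):
#       print("running", prim.name)
#       return _run_op_sync(prim, prim.name, args)
#
#   with _RunOpHook(my_trace_hook):
#       out = some_primitive(*inputs)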


def _run_op(obj, op_name, args):
    """Single op execution function supported by GE in PyNative mode."""
    if not _RunOpHook.current:
        stub = _pynative_executor.run_op_async(obj, op_name, args)
        return _convert_stub(stub)
    return _RunOpHook.current.hook(obj, args)


@_wrap_func
def _run_op_sync(obj, op_name, args):
    """Single op execution function in synchronous mode."""
    output = _pynative_executor.real_run_op(obj, op_name, args)
    return output


class _PrimitiveC(Primitive):
    def __init__(self, name, attrs):
        super().__init__(name)
        for key, value in attrs.items():
            super().add_prim_attr(key, value)


def _get_primitivec(name, attrs):
    return _PrimitiveC(name, attrs)


def _create_primitive_function_obj():
    return PrimitiveFunction_()