# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""primitive"""
import inspect
import copy
from mindspore.common.api import _wrap_func
from mindspore import context, log as logger
from mindspore.parallel._utils import _is_in_auto_parallel_mode
from .._c_expression import Primitive_, real_run_op, prim_type
from .._checkparam import Validator
from . import signature as sig


class Primitive(Primitive_):
    """
    Primitive is the base class of operator primitives in Python.

    Args:
        name (str): Name for the current Primitive.

    Examples:
        >>> add = Primitive('add')
        >>>
        >>> # or work with prim_attr_register:
        >>> # init a Primitive class with attr1 and attr2
        >>> class Add(Primitive):
        ...     @prim_attr_register
        ...     def __init__(self, attr1, attr2):
        ...         '''init for add'''
        ...     # check attr1 and attr2 or do some initializations
        ...     # init a Primitive obj with attr1=1 and attr2=2
        >>> add = Add(attr1=1, attr2=2)
    """
    _repr_ignore_list = ['input_names', 'output_names']

    def __init__(self, name):
        self.name = name
        self.attrs = {}
        self.init_attrs = {"name": name}
        self._update_parameter = False
        Primitive_.__init__(self, name)
        if hasattr(self.__class__, '__mindspore_signature__'):
            out = self._fill_signature(self.__class__.__mindspore_signature__)
            self.set_signatures(out)

    def _fill_signature(self, signatures):
        """fills signature."""
        signatures_new = []
        for signature in signatures:
            if isinstance(signature, sig.Signature):
                signatures_new.append(signature)
            elif isinstance(signature, sig.sig_dtype):
                signatures_new.append(sig.make_sig(dtype=signature))
            else:
                if len(signature) < 3:
                    raise ValueError(f"[Internal Error] Signature for one parameter must have at least "
                                     f"3 elements, but got {signature}")
                signatures_new.append(sig.make_sig(*signature))
        return tuple(signatures_new)

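    # Illustrative note on the '__mindspore_signature__' hook consumed in
    # __init__ above (a sketch; 'HypotheticalAssign' and its dtype constraints
    # are assumptions, though real operators such as Assign follow this
    # pattern):
    #
    #     class HypotheticalAssign(Primitive):
    #         __mindspore_signature__ = (
    #             sig.make_sig('variable', sig.sig_rw.RW_WRITE, dtype=sig.sig_dtype.T),
    #             sig.make_sig('value', dtype=sig.sig_dtype.T),
    #         )
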
    def _clone(self):
        """
        Deeply clones the primitive object.

        Calls the __init__() method with the same arguments. This method is called in the parser if the
        flag self.__setattr_flag__ is True.
        """
        cloned = copy.deepcopy(self)
        init_params = inspect.getfullargspec(cloned.__init__.decorated_func).args[1:]
        init_args = {}
        for name in init_params:
            value = self.attrs[name]
            init_args[name] = value
        # __init__ should be called to construct cpp object.
        cloned.__init__(**init_args)
        for name in self.attrs:
            value = self.attrs[name]
            cloned.add_prim_attr(name, value)
        if hasattr(self, 'instance_name'):
            cloned.set_prim_instance_name(self.instance_name)
        return cloned

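    # Note on _clone above: copy.deepcopy routes through __deepcopy__ (defined
    # later in this class), which rebuilds the object from init_attrs; calling
    # __init__ again then reconstructs the underlying C++ object, and
    # add_prim_attr restores any attributes added after construction.
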
    def add_prim_attr(self, name, value):
        """
        Add primitive attribute.

        Args:
            name (str): Attribute name.
            value (Any): Attribute value.

        Examples:
            >>> import mindspore.ops as ops
            >>> a = ops.Add()
            >>> a = a.add_prim_attr("attr", 1)
            >>> out = a.attrs["attr"]
            >>> print(out)
            1
        """
        self.__dict__[name] = value
        self.attrs[name] = value
        self.add_attr(name, value)
        return self

    def del_prim_attr(self, name):
        """
        Delete primitive attribute.

        Args:
            name (str): Attribute name.
        Examples:
            >>> import mindspore.ops as ops
            >>> a = ops.Add()
            >>> a = a.add_prim_attr("attr", 1)
            >>> a = a.del_prim_attr("attr")
            >>> print(a.attrs)
            {'input_names': ['x', 'y'], 'output_names': ['output']}
        """
        if name in self.__dict__ and name in self.attrs:
            del self.__dict__[name]
            del self.attrs[name]
            self.del_attr(name)
        return self

    def set_stage(self, stage):
        """
        Add stage id to primitive attribute.

        Note:
            It is valid only in semi auto parallel mode.
            In other parallel modes, please set it to 0.
        Args:
            stage (int): The stage id for the current operation.
        Examples:
            >>> from mindspore import ops
            >>> add = ops.Add()
            >>> print(add.set_stage(0))
            Prim[Add]<stage=0>
        """
        self.add_prim_attr("stage", stage)
        return self

    def shard(self, strategy):
        """
        Add strategies to primitive attribute.

        Note:
            It is valid only in semi auto parallel or auto parallel mode.
            In other parallel modes, strategies set here will be ignored.

        Args:
            strategy (tuple): Strategy describes the distributed parallel mode of the current primitive.
        Examples:
            >>> from mindspore import ops
            >>> add = ops.Add()
            >>> print(add.shard(((1, 1), (1, 1))))
            Prim[Add]<strategy=((1, 1), (1, 1))>
        """
        mode = context.get_auto_parallel_context("parallel_mode")
        if strategy is not None:
            if not isinstance(strategy, tuple):
                raise TypeError(f'strategy must be tuple type, but got: {type(strategy)}')
            for ele in strategy:
                if not isinstance(ele, tuple):
                    raise TypeError(f'The element of strategy must be tuple type, but got: {type(ele)}')
        if not _is_in_auto_parallel_mode() and strategy:
            logger.warning(f"The shard strategy {strategy} of {self.name} is not valid in {mode}. "
                           f"Please use semi auto or auto parallel mode.")
        self.add_prim_attr("strategy", strategy)
        return self

    def set_prim_instance_name(self, instance_name):
        """
        Set instance name to primitive operator.

        Note:
            It will be called by default when the user defines a primitive operator.

        Args:
            instance_name (str): Instance name of the primitive operator set by user.
        Examples:
            >>> import mindspore.ops as ops
            >>> a = ops.Add()
            >>> a.set_prim_instance_name("add")
            >>> print(a.instance_name)
            add
        """
        self.set_instance_name(instance_name)
        self.instance_name = instance_name
        return self

    def __getattr__(self, item):
        if item == 'infer_dynamic_shape':
            return None
        if item in super().get_attr_dict():
            return super().get_attr_dict()[item]
        if item in self.attrs:
            return self.attrs[item]
        raise AttributeError(item)

    def check_elim(self, *args):
        """
        Check if the primitive can be eliminated. Subclasses that need this should override the method.

        Args:
            args (Primitive args): Same as the arguments of the current Primitive.

        Returns:
            A tuple consisting of two elements.
            The first element indicates whether the primitive can be evaluated at compile time;
            the second element is the evaluated result.

        Examples:
            >>> class AddN(Primitive):
            ...     @prim_attr_register
            ...     def __init__(self):
            ...         self.init_prim_io_names(inputs=["inputs"], outputs=["sum"])
            ...     def check_elim(self, inputs):
            ...         if len(inputs) != 1:
            ...             return (False, None)
            ...         if isinstance(inputs[0], Tensor):
            ...             return (True, inputs[0])
            ...
            >>> addn = AddN()
            >>> input_x = Tensor(np.array([1, 2, 3]), mindspore.float32)
            >>> output = addn.check_elim((input_x,))
            >>> print(output)
            (True, Tensor(shape=[3], dtype=Float32, value= [ 1.00000000e+00,  2.00000000e+00,  3.00000000e+00]))
        """
        return (False, None)

    def __call__(self, *args):
        should_elim, output = self.check_elim(*args)
        if should_elim:
            return output
        return _run_op(self, self.name, args)

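    # Illustrative use of __call__ above (a sketch; assumes PyNative mode, in
    # which calling an operator instance executes it eagerly):
    #
    #     add = ops.Add()
    #     out = add(x, y)  # folded by check_elim if possible, else _run_op
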
    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, d):
        self.__dict__.update(d)

    def __deepcopy__(self, memo):
        return type(self)(**self.init_attrs)

    def __repr__(self):
        attr = ', '.join([f'{k}={self.attrs[k]}' for k in self.attrs if k not in Primitive._repr_ignore_list])
        info_str = f'Prim[{self.name}]'
        if attr:
            info_str += f'<{attr}>'
        return info_str

    def init_prim_io_names(self, inputs, outputs):
        """
        Initialize the names of the inputs and outputs of the operator.

        Args:
            inputs (list[str]): List of input names.
            outputs (list[str]): List of output names.
        Examples:
            >>> import mindspore.ops as ops
            >>> a = ops.Add()
            >>> a.init_prim_io_names(["x", "y"], ["sum"])
            >>> print(a.input_names)
            ['x', 'y']
            >>> print(a.output_names)
            ['sum']
        """
        # for checking para names with kernel implementation
        self.add_prim_attr("input_names", inputs)
        # for checking output number with kernel implementation
        self.add_prim_attr("output_names", outputs)

    @property
    def update_parameter(self):
        """Return whether the primitive will update the value of parameter."""
        return self._update_parameter

    def recompute(self, mode=True):
        """
        Set the primitive to be recomputed. If a primitive set to be recomputed feeds into some backward
        nodes for computing gradients, then instead of storing the intermediate activation computed in the
        forward pass, the activation is recomputed in the backward pass.

        Note:
            - If the computation involves something like randomization or global variables, the equivalence
              is not guaranteed currently.
            - Not supported in PyNative mode.

        Args:
            mode (bool): Specifies whether the primitive is recomputed. Default: True.
        Examples:
            >>> import numpy as np
            >>> import mindspore as ms
            >>> from mindspore import Tensor, ops, nn
            >>> class NetRecompute(nn.Cell):
            ...     def __init__(self):
            ...         super(NetRecompute, self).__init__()
            ...         self.relu = ops.ReLU().recompute()
            ...         self.sqrt = ops.Sqrt()
            ...     def construct(self, x):
            ...         out = self.relu(x)
            ...         return self.sqrt(out)
            ...
            >>> class GradNet(nn.Cell):
            ...     def __init__(self, network):
            ...         super(GradNet, self).__init__()
            ...         self.network = network
            ...         self.grad = ops.GradOperation()
            ...     def construct(self, x):
            ...         g_out = self.grad(self.network)(x)
            ...         return g_out
            ...
            >>> x = Tensor(np.array([-1, 1]).astype(np.float32))
            >>> net = NetRecompute()
            >>> grad = GradNet(net)
            >>> a = grad(x)
            >>> print(a)
            [0. 0.5]
        """
        if context.get_context("mode") == context.PYNATIVE_MODE:
            raise TypeError("Recompute is not supported in PyNative mode currently.")
        Validator.check_bool(mode)
        self.add_prim_attr("recompute", mode)
        return self


class PrimitiveWithCheck(Primitive):
    """
    PrimitiveWithCheck is the base class of primitives in Python that define check functions for the
    operator input arguments, while using the infer method registered in the C++ source code.

    There are three methods that can be overridden to define the check logic of the primitive: __check__(),
    check_shape(), and check_dtype(). If __check__() is defined in the primitive, it has the highest
    priority to be called. If __check__() is not defined, check_shape() and check_dtype() can be defined to
    describe the check logic of the shape and type. The infer_value() method can also be defined (as in
    PrimitiveWithInfer) for constant propagation.

    Args:
        name (str): Name of the current Primitive.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # init a Primitive class with check
        >>> class Flatten(PrimitiveWithCheck):
        ...     @prim_attr_register
        ...     def __init__(self):
        ...         pass
        ...     def check_shape(self, input_x):
        ...         validator.check_int(len(input_x), 1, Rel.GE, 'input_x rank', self.name)
        ...
        ...     def check_dtype(self, input_x):
        ...         validator.check_subclass("input_x", input_x, mstype.tensor, self.name)
        ...
        >>> # init a Primitive obj
        >>> add = Flatten()
    """

    def __init__(self, name):
        Primitive.__init__(self, name)
        self.set_prim_type(prim_type.py_infer_check)

    def _clone(self):
        """
        Deeply clones the primitive object.

        Calls the __init__() method with the same arguments. This method is called in the parser if the
        flag self.__setattr_flag__ is True.
        """
        cloned_prim = Primitive._clone(self)
        return cloned_prim

    def check_shape(self, *args):
        """
        Check the shapes of the input args.

        Note:
            The shape of a scalar is an empty tuple.

        Args:
            args (tuple(int)): Shapes of input tensors.

        Returns:
            None.
        """
        return None

    def check_dtype(self, *args):
        """
        Check the data types of the input args.

        Args:
            args (:class:`mindspore.dtype`): Data types of inputs.

        Returns:
            None.
        """
        return None

    def __check__(self, *args):
        """Check that the input shapes and input types of the op are valid."""
        tracks = ['dtype', 'shape']
        for track in tracks:
            fn = getattr(self, 'check_' + track)
            fn(*(x[track] for x in args))

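    # Sketch of the dictionary form __check__ above receives for each input
    # (the concrete values are illustrative assumptions):
    #
    #     input_x = {'shape': (2, 3), 'dtype': mstype.float32, 'value': None}
    #     flatten.__check__(input_x)  # calls check_dtype, then check_shape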

class PrimitiveWithInfer(Primitive):
    """
    PrimitiveWithInfer is the base class of primitives in Python and defines functions for tracking
    inference in Python.

    There are four methods that can be overridden to define the infer logic of the primitive: __infer__(),
    infer_shape(), infer_dtype(), and infer_value(). If __infer__() is defined in the primitive, it has the
    highest priority to be called. If __infer__() is not defined, infer_shape() and infer_dtype() can be
    defined to describe the infer logic of the shape and type. The infer_value() method is used for
    constant propagation.

    Args:
        name (str): Name of the current Primitive.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # init a Primitive class with infer
        >>> class Add(PrimitiveWithInfer):
        ...     @prim_attr_register
        ...     def __init__(self):
        ...         pass
        ...
        ...     def infer_shape(self, x, y):
        ...         return x  # output shape same as first input 'x'
        ...
        ...     def infer_dtype(self, x, y):
        ...         return x  # output type same as first input 'x'
        ...
        >>> # init a Primitive obj
        >>> add = Add()
    """

    def __init__(self, name):
        Primitive.__init__(self, name)
        self.set_prim_type(prim_type.py_infer_shape)

    def _clone(self):
        """
        Deeply clones the primitive object.

        Calls the __init__() method with the same arguments. This method is called in the parser if the
        flag self.__setattr_flag__ is True.
        """
        cloned_prim = Primitive._clone(self)
        return cloned_prim

    def infer_shape(self, *args):
        """
        Infer output shape based on input shape.

        Note:
            The shape of a scalar is an empty tuple.

        Args:
            args (tuple(int)): Shapes of input tensors.

        Returns:
            `tuple(int)`, shapes of output tensors.
        """
        return None

    def infer_dtype(self, *args):
        """
        Infer output dtype based on input dtype.

        Args:
            args (:class:`mindspore.dtype`): Data types of inputs.

        Returns:
            :class:`mindspore.dtype`, data type of outputs.
        """
        return None

    def infer_value(self, *args):
        """
        Infer output value based on input value at compile time.

        Args:
            args (Any): Values of inputs.

        Returns:
            Value of outputs. Returning `None` means the value cannot be inferred at compile time.
        """
        return None

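    # A minimal, hypothetical subclass overriding infer_value above so that
    # the result is folded at compile time when all inputs are constant (a
    # sketch, not an operator defined in this module):
    #
    #     class ConstAdd(PrimitiveWithInfer):
    #         @prim_attr_register
    #         def __init__(self):
    #             pass
    #         def infer_value(self, x, y):
    #             if x is None or y is None:
    #                 return None  # not constant; leave the op in the graph
    #             return Tensor(x.asnumpy() + y.asnumpy())
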
    def __infer__(self, *args):
        """Infer shape, type, and value at the same time by using dictionaries as arguments."""
        is_graph_mode = context.get_context("mode") == context.GRAPH_MODE
        fn_infer_dynamic_shape = getattr(self, 'infer_dynamic_shape', None)
        if is_graph_mode and fn_infer_dynamic_shape is not None:
            out = fn_infer_dynamic_shape(*args)
            tracks = ['dtype', 'value']
            for track in tracks:
                fn = getattr(self, 'infer_' + track)
                # fn may return None
                out[track] = fn(*(x[track] for x in args))
            return out

        tracks = ['dtype', 'shape', 'value']
        out = {}
        for track in tracks:
            fn = getattr(self, 'infer_' + track)
            # fn may return None
            out[track] = fn(*(x[track] for x in args))

        # in non-graph mode, it is not necessary to infer min/max shape
        if not is_graph_mode:
            return out

        # output does not contain dynamic shape, no need to calculate min/max shape
        def has_dynamic_shape(shp):
            if isinstance(shp, int):
                return shp < 0
            if isinstance(shp, (list, tuple)):
                return any(has_dynamic_shape(e) for e in shp)
            return False

        if not has_dynamic_shape(out['shape']):
            return out

        # calculate min/max shape for output
        def get_specified_shape(elems, attr):
            has_specified_shape = False
            ret_vals = []
            for elem in elems:
                if attr in elem:
                    has_specified_shape = True
                    ret_vals.append(elem[attr])
                else:
                    ret_vals.append(elem['shape'])
            return has_specified_shape, tuple(ret_vals)

        has_min_shape, min_shapes = get_specified_shape(args, 'min_shape')
        has_max_shape, max_shapes = get_specified_shape(args, 'max_shape')
        if not (has_min_shape or has_max_shape):
            return out
        if has_min_shape and has_max_shape:
            fn_infer_min_shape = getattr(self, 'infer_shape')
            fn_infer_max_shape = fn_infer_min_shape
            if hasattr(self, 'infer_min_shape'):
                fn_infer_min_shape = getattr(self, 'infer_min_shape')
            if hasattr(self, 'infer_max_shape'):
                fn_infer_max_shape = getattr(self, 'infer_max_shape')
            out['min_shape'] = fn_infer_min_shape(*min_shapes)
            out['max_shape'] = fn_infer_max_shape(*max_shapes)
            return out
        raise ValueError(f'Input args have an invalid dynamic shape, args info: {args}')

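    # Sketch of how __infer__ above consumes its arguments (values are
    # illustrative assumptions): each input arrives as a dict of tracks, and a
    # dynamic dimension is encoded as a negative size with optional min/max
    # bounds supplied alongside it.
    #
    #     x = {'shape': (-1, 3), 'dtype': mstype.float32, 'value': None,
    #          'min_shape': (1, 3), 'max_shape': (16, 3)}
    #     out = some_prim.__infer__(x)
    #     # out: {'shape': ..., 'dtype': ..., 'value': ..., 'min_shape': ..., 'max_shape': ...}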

def prim_attr_register(fn):
    """
    Primitive attributes register.

    Register the decorator of the built-in operator primitive's '__init__' function.
    The decorator will add all the parameters of '__init__' as operator attributes,
    and initialize the primitive name.

    Args:
        fn (function): __init__ function of primitive.

    Returns:
        function, original function.

    Examples:
        >>> class MatMul(PrimitiveWithCheck):
        ...     @prim_attr_register
        ...     def __init__(self, transpose_a=False, transpose_b=False):
        ...         self.init_prim_io_names(inputs=['x1', 'x2'], outputs=['output'])
        ...
        >>> # init a Primitive obj
        >>> matmul = MatMul()
    """

    def deco(self, *args, **kwargs):
        class_name = self.__class__.__name__
        if hasattr(self.__class__, "substitute_name"):
            class_name = self.__class__.substitute_name
        if isinstance(self, PrimitiveWithInfer):
            PrimitiveWithInfer.__init__(self, class_name)
        elif isinstance(self, PrimitiveWithCheck):
            PrimitiveWithCheck.__init__(self, class_name)
        else:
            Primitive.__init__(self, self.__class__.__name__)
        bound_args = inspect.signature(fn).bind(self, *args, **kwargs)
        bound_args.apply_defaults()
        arguments = bound_args.arguments
        del arguments['self']
        del self.init_attrs['name']
        for name in arguments:
            value = arguments[name]
            self.add_prim_attr(name, value)
            self.init_attrs[name] = value
        fn(self, *args, **kwargs)

    deco.decorated_func = fn
    return deco


def constexpr(fn=None, get_instance=True, name=None):
    """
    Creates a PrimitiveWithInfer operator that can infer the value at compile time. We can use it to define a function
    to compute a constant value using the constants in the constructor.

    Args:
        fn (function): A `fn` used as the infer_value of the output operator. Default: None.
        get_instance (bool): If true, return the instance of operator,
                             otherwise return the operator class. Default: True.
        name (str): Defines the operator name. If `name` is None, use the function name as the op name. Default: None.

    Examples:
        >>> from mindspore.ops import constexpr
        >>> a = (1, 2)
        >>> # make an operator to calculate tuple len
        >>> @constexpr
        ... def tuple_len(x):
        ...     return len(x)
        ...
        >>> print(tuple_len(a))
        2
        >>> # make an operator class to calculate tuple len
        >>> @constexpr(get_instance=False, name="TupleLen")
        ... def tuple_len_class(x):
        ...     return len(x)
        ...
        >>> print(tuple_len_class()(a))
        2
    """

    def deco(fn):
        """Decorator for CompileOp."""

        class CompileOp(PrimitiveWithInfer):
            """
            CompileOp is a temporary operator used to execute the constexpr function.
            """

            def __init__(self):
                op_name = name if name else fn.__name__
                PrimitiveWithInfer.__init__(self, op_name)
                self.set_const_prim(True)

            def infer_value(self, *args):
                return fn(*args)

            def __call__(self, *args, **kwargs):
                return fn(*args)

        if get_instance:
            return CompileOp()
        return CompileOp

    if fn is not None:
        return deco(fn)
    return deco


@_wrap_func
def _run_op(obj, op_name, args):
    """Single op execution function supported by GE in PyNative mode."""
    output = real_run_op(obj, op_name, args)
    return output