# This is the Python adaptation and derivative work of Myia (https://github.com/mila-iqia/myia/).
#
# Copyright 2020-2021 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""Basic composite operations."""
from functools import partial
from types import FunctionType

from mindspore import context
from ..._c_expression import EnvInstance_, GradOperation_, HyperMap_, Map_, MultitypeFuncGraph_, Tail_, \
    TupleAdd_, TupleSlice_, UnpackCall_, ZipOperation_, ListAppend_, TupleGetItemTensor_
from ...common import dtype as mstype
from ...common.api import ms_function, _pynative_executor, _wrap_func
from ..primitive import Primitive
from ..operations import _grad_ops
from .. import operations as P
from .. import signature as sig

__all__ = [EnvInstance_, TupleAdd_, TupleSlice_, UnpackCall_, TupleGetItemTensor_]


def add_flags(fn=None, **flags):
    """
    A decorator that adds flags to the function.

    Note:
        Only supports bool values.

    Args:
        fn (Function): Function or cell to which the flags are added. Default: None.
        flags (dict): Flags, passed as keyword arguments. Default: None.

    Returns:
        Function, the function with added flags.

    Examples:
        >>> net = Net()
        >>> net = add_flags(net, predict=True)
        >>> print(hasattr(net, '_mindspore_flags'))
        True
    """
    def deco(fn):
        # The attribute is set here and accessed on the C++ side.
        if not hasattr(fn, "_mindspore_flags"):
            fn._mindspore_flags = {}

        fn._mindspore_flags.update({**flags})
        return fn
    ret = deco
    if fn is not None:
        ret = deco(fn)
    return ret
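

# A minimal usage sketch (not part of the original module): because `fn` defaults
# to None, `add_flags` also works as a parameterized decorator. The flag name
# `has_effect` below is only illustrative.
def _add_flags_decorator_sketch():
    @add_flags(has_effect=True)
    def my_func(x):
        return x

    # The decorated function now carries the flag dictionary.
    return my_func._mindspore_flags  # {'has_effect': True}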


def core(fn=None, **flags):
    """
    A decorator that adds a flag to the function.

    By default, the function is marked with `core=True`, enabling this decorator to
    be used to set flags on a graph.

    Args:
        fn (Function): Function to which the flags are added. Default: None.
        flags (dict): Flags, passed as keyword arguments. The `core` flag is always set to True,
                      indicating that this is a core function; other flags can be set as well.
                      Default: None.

    Returns:
        Function, the function with added flags.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> net = Net()
        >>> net = core(net, predict=True)
        >>> print(hasattr(net, '_mindspore_flags'))
        True
    """
    # The attribute is set here and accessed on the C++ side.

    def deco(fn):
        fn._mindspore_flags = {
            'core': True,
            **flags,
        }
        return fn

    if fn is not None:
        ret = deco(fn)
    else:
        ret = deco
    return ret
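

# A minimal sketch (not part of the original module): `core` can also be applied as a
# bare decorator, tagging the wrapped function with `_mindspore_flags = {'core': True}`.
def _core_decorator_sketch():
    @core
    def identity(x):
        return x

    return identity._mindspore_flags.get('core', False)  # True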


class GradOperation(GradOperation_):
    """
    A higher-order function which is used to generate the gradient function for the input function.

    The gradient function generated by the `GradOperation` higher-order function can be customized by
    construction arguments.

    Given an input function `net = Net()` that takes `x` and `y` as inputs and has a parameter `z`,
    see `Net` in Examples.


    To generate a gradient function that returns gradients with respect to the first input
    (see `GradNetWrtX` in Examples):

    1. Construct a `GradOperation` higher-order function with default arguments:
       `grad_op = GradOperation()`.

    2. Call it with the input function as argument to get the gradient function: `gradient_function = grad_op(net)`.

    3. Call the gradient function with the input function's inputs to get the gradients with respect to the first
       input: `gradient_function(x, y)`.


    To generate a gradient function that returns gradients with respect to all inputs (see `GradNetWrtXY` in Examples):

    1. Construct a `GradOperation` higher-order function with `get_all=True`, which
       indicates getting gradients with respect to all inputs; in the example function `Net()` these are `x` and `y`:
       `grad_op = GradOperation(get_all=True)`.

    2. Call it with the input function as argument to get the gradient function: `gradient_function = grad_op(net)`.

    3. Call the gradient function with the input function's inputs to get the gradients with respect to all inputs:
       `gradient_function(x, y)`.

    To generate a gradient function that returns gradients with respect to given parameters
    (see `GradNetWithWrtParams` in Examples):

    1. Construct a `GradOperation` higher-order function with `get_by_list=True`:
       `grad_op = GradOperation(get_by_list=True)`.

    2. Construct a `ParameterTuple` that will be passed along with the input function when building the gradient
       function; it is used as a parameter filter that determines which gradients to return:
       `params = ParameterTuple(net.trainable_params())`.

    3. Call it with the input function and `params` as arguments to get the gradient function:
       `gradient_function = grad_op(net, params)`.

    4. Call the gradient function with the input function's inputs to get the gradients with
       respect to the given parameters: `gradient_function(x, y)`.

    To generate a gradient function that returns gradients with respect to all inputs and given parameters
    in the format of ((dx, dy), (dz)) (see `GradNetWrtInputsAndParams` in Examples):

    1. Construct a `GradOperation` higher-order function with `get_all=True` and `get_by_list=True`:
       `grad_op = GradOperation(get_all=True, get_by_list=True)`.

    2. Construct a `ParameterTuple` that will be passed along with the input function when building the gradient
       function: `params = ParameterTuple(net.trainable_params())`.

    3. Call it with the input function and `params` as arguments to get the gradient function:
       `gradient_function = grad_op(net, params)`.

    4. Call the gradient function with the input function's inputs
       to get the gradients with respect to all inputs and the given parameters: `gradient_function(x, y)`.


    We can configure the sensitivity (gradient with respect to output) by setting `sens_param` to True and
    passing an extra sensitivity input to the gradient function; the sensitivity input should have the
    same shape and type as the input function's output (see `GradNetWrtXYWithSensParam` in Examples).

    1. Construct a `GradOperation` higher-order function with `get_all=True` and `sens_param=True`:
       `grad_op = GradOperation(get_all=True, sens_param=True)`.

    2. Define `grad_wrt_output` as `sens_param`, which works as the gradient with respect to output:
       `grad_wrt_output = Tensor(np.ones([2, 2]).astype(np.float32))`.

    3. Call it with the input function as argument to get the gradient function:
       `gradient_function = grad_op(net)`.

    4. Call the gradient function with the input function's inputs and `sens_param` to
       get the gradients with respect to all inputs:
       `gradient_function(x, y, grad_wrt_output)`.

    Args:
        get_all (bool): If True, get all the gradients with respect to inputs. Default: False.
        get_by_list (bool): If True, get all the gradients with respect to Parameter variables.
            If get_all and get_by_list are both False, get the gradient with respect to the first input.
            If get_all and get_by_list are both True, get the gradients with respect to inputs and Parameter variables
            at the same time in the form of ((gradients with respect to inputs),
            (gradients with respect to parameters)). Default: False.
        sens_param (bool): Whether to append sensitivity (gradient with respect to output) as input.
            If sens_param is False, a 'ones_like(outputs)' sensitivity will be attached automatically.
            If sens_param is True, a sensitivity (gradient with respect to output) needs to be passed
            through a positional argument or a keyword argument. If it is passed through a keyword
            argument, the key must be 'sens'. Default: False.

    Returns:
        The higher-order function which takes a function as argument and returns the gradient function for it.

    Raises:
        TypeError: If `get_all`, `get_by_list` or `sens_param` is not a bool.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import ParameterTuple
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.matmul = P.MatMul()
        ...         self.z = Parameter(Tensor(np.array([1.0], np.float32)), name='z')
        ...     def construct(self, x, y):
        ...         x = x * self.z
        ...         out = self.matmul(x, y)
        ...         return out
        ...
        >>> class GradNetWrtX(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNetWrtX, self).__init__()
        ...         self.net = net
        ...         self.grad_op = GradOperation()
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net)
        ...         return gradient_function(x, y)
        ...
        >>> x = Tensor([[0.5, 0.6, 0.4], [1.2, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.01, 0.3, 1.1], [0.1, 0.2, 1.3], [2.1, 1.2, 3.3]], dtype=mstype.float32)
        >>> output = GradNetWrtX(Net())(x, y)
        >>> print(output)
        [[1.4100001 1.5999999 6.6      ]
         [1.4100001 1.5999999 6.6      ]]
        >>>
        >>> class GradNetWrtXY(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNetWrtXY, self).__init__()
        ...         self.net = net
        ...         self.grad_op = GradOperation(get_all=True)
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net)
        ...         return gradient_function(x, y)
        >>>
        >>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
        >>> output = GradNetWrtXY(Net())(x, y)
        >>> print(output)
        (Tensor(shape=[2, 3], dtype=Float32, value=
        [[ 4.50999975e+00,  2.70000005e+00,  3.60000014e+00],
         [ 4.50999975e+00,  2.70000005e+00,  3.60000014e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
        [[ 2.59999990e+00,  2.59999990e+00,  2.59999990e+00],
         [ 1.89999998e+00,  1.89999998e+00,  1.89999998e+00],
         [ 1.30000007e+00,  1.30000007e+00,  1.30000007e+00]]))
        >>>
        >>> class GradNetWrtXYWithSensParam(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNetWrtXYWithSensParam, self).__init__()
        ...         self.net = net
        ...         self.grad_op = GradOperation(get_all=True, sens_param=True)
        ...         self.grad_wrt_output = Tensor([[0.1, 0.6, 0.2], [0.8, 1.3, 1.1]], dtype=mstype.float32)
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net)
        ...         return gradient_function(x, y, self.grad_wrt_output)
        >>>
        >>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
        >>> output = GradNetWrtXYWithSensParam(Net())(x, y)
        >>> print(output)
        (Tensor(shape=[2, 3], dtype=Float32, value=
        [[ 2.21099997e+00,  5.09999990e-01,  1.49000001e+00],
         [ 5.58800030e+00,  2.68000007e+00,  4.07000017e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
        [[ 1.51999998e+00,  2.81999993e+00,  2.14000010e+00],
         [ 1.09999990e+00,  2.04999995e+00,  1.54999995e+00],
         [ 9.00000036e-01,  1.54999995e+00,  1.25000000e+00]]))
        >>>
        >>> class GradNetWithWrtParams(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNetWithWrtParams, self).__init__()
        ...         self.net = net
        ...         self.params = ParameterTuple(net.trainable_params())
        ...         self.grad_op = GradOperation(get_by_list=True)
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net, self.params)
        ...         return gradient_function(x, y)
        >>>
        >>> x = Tensor([[0.8, 0.6, 0.2], [1.8, 1.3, 1.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.11, 3.3, 1.1], [1.1, 0.2, 1.4], [1.1, 2.2, 0.3]], dtype=mstype.float32)
        >>> output = GradNetWithWrtParams(Net())(x, y)
        >>> print(output)
        (Tensor(shape=[1], dtype=Float32, value= [ 2.15359993e+01]),)
        >>>
        >>> class GradNetWrtInputsAndParams(nn.Cell):
        ...     def __init__(self, net):
        ...         super(GradNetWrtInputsAndParams, self).__init__()
        ...         self.net = net
        ...         self.params = ParameterTuple(net.trainable_params())
        ...         self.grad_op = GradOperation(get_all=True, get_by_list=True)
        ...     def construct(self, x, y):
        ...         gradient_function = self.grad_op(self.net, self.params)
        ...         return gradient_function(x, y)
        >>>
        >>> x = Tensor([[0.1, 0.6, 1.2], [0.5, 1.3, 0.1]], dtype=mstype.float32)
        >>> y = Tensor([[0.12, 2.3, 1.1], [1.3, 0.2, 2.4], [0.1, 2.2, 0.3]], dtype=mstype.float32)
        >>> output = GradNetWrtInputsAndParams(Net())(x, y)
        >>> print(output)
        ((Tensor(shape=[2, 3], dtype=Float32, value=
        [[ 3.51999998e+00,  3.90000010e+00,  2.59999990e+00],
         [ 3.51999998e+00,  3.90000010e+00,  2.59999990e+00]]), Tensor(shape=[3, 3], dtype=Float32, value=
        [[ 6.00000024e-01,  6.00000024e-01,  6.00000024e-01],
         [ 1.89999998e+00,  1.89999998e+00,  1.89999998e+00],
         [ 1.30000007e+00,  1.30000007e+00,  1.30000007e+00]])), (Tensor(shape=[1], dtype=Float32, value=
         [ 1.29020004e+01]),))
    """

    def __init__(self, get_all=False, get_by_list=False, sens_param=False):
        """Initialize GradOperation."""
        if not isinstance(get_all, bool):
            raise TypeError(f"For 'GradOperation', the 'get_all' should be bool, but got {type(get_all).__name__}")
        if not isinstance(get_by_list, bool):
            raise TypeError(f"For 'GradOperation', the 'get_by_list' should be bool, "
                            f"but got {type(get_by_list).__name__}")
        if not isinstance(sens_param, bool):
            raise TypeError(f"For 'GradOperation', the 'sens_param' should be bool, "
                            f"but got {type(sens_param).__name__}")
        self.get_all = get_all
        self.get_by_list = get_by_list
        self.sens_param = sens_param
        GradOperation_.__init__(self, 'grad', get_all, get_by_list, sens_param)
        self.grad_fn = None
        self.fn = None
        self.pynative_ = False

    def _pynative_forward_run(self, grad, args, kwargs, fn):
        """Run the forward network in PyNative mode to build the grad graph."""
        new_kwargs = kwargs
        if self.sens_param:
            if 'sens' not in kwargs:
                args = args[:-1]
            else:
                new_kwargs = kwargs.copy()
                new_kwargs.pop('sens')
        if isinstance(fn, FunctionType):
            if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
                _pynative_executor.set_grad_flag(True)
                _pynative_executor.new_graph(fn, *args, **new_kwargs)
                output = fn(*args, **new_kwargs)
                _pynative_executor.end_graph(fn, output, *args, **new_kwargs)
        else:
            # Check whether fn has already run.
            if not _pynative_executor.check_run(grad, fn, *args, **new_kwargs):
                fn.set_grad()
                fn(*args, **new_kwargs)
                fn.set_grad(False)

    def __call__(self, fn, weights=None):
        if self.grad_fn is not None and self.fn == fn:
            return self.grad_fn
        grad_ = GradOperation(self.get_all, self.get_by_list, self.sens_param)
        # If grad is called in GRAPH_MODE or inside an ms_function, run grad in GRAPH_MODE.
        # If grad is called in pure PYNATIVE_MODE, run grad in PYNATIVE_MODE.
        #   In pure PYNATIVE_MODE, the outer after_grad is only used to set the pynative flag
        #   for the inner GradOperation.
        #   In PYNATIVE_MODE, when grad is called from an ms_function, the outer after_grad runs grad in GRAPH_MODE.
        if context.get_context("mode") == context.GRAPH_MODE:
            if self.get_by_list:
                @ms_function
                def after_grad(*args):
                    return grad_(fn, weights)(*args)
            else:
                @ms_function
                def after_grad(*args):
                    return grad_(fn)(*args)
        elif self.pynative_:
            @_wrap_func
            def after_grad(*args, **kwargs):
                if _pynative_executor.check_graph(fn, *args, **kwargs):
                    print("Another grad step is running")
                self._pynative_forward_run(grad_, args, kwargs, fn)
                _pynative_executor.grad(grad_, fn, weights, *args, **kwargs)
                out = _pynative_executor(fn, *args, **kwargs)
                _pynative_executor.clear_grad(fn, *args, **kwargs)
                return out
        else:
            grad_.pynative_ = True
            # after_grad in this branch cannot use @ms_function; call grad_ directly.
            if self.get_by_list:
                def after_grad(*args, **kwargs):
                    return grad_(fn, weights)(*args, **kwargs)
            else:
                def after_grad(*args, **kwargs):
                    return grad_(fn)(*args, **kwargs)

        self.grad_fn = after_grad
        self.fn = fn
        return self.grad_fn
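

# A minimal sketch (not part of the original module): `GradOperation` can also wrap a
# plain Python function (see the FunctionType branch in `_pynative_forward_run`), and the
# resulting gradient function is cached on the instance via `self.grad_fn`/`self.fn`.
# The tensor value and the expected result in the comment are only illustrative.
def _grad_operation_function_sketch():
    import numpy as np
    from mindspore import Tensor

    def square(x):
        return x * x

    grad_op = GradOperation()
    grad_fn = grad_op(square)                             # build (and cache) the gradient function
    return grad_fn(Tensor(np.array([3.0], np.float32)))   # d(x*x)/dx = 2x -> [6.0]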


class MultitypeFuncGraph(MultitypeFuncGraph_):
    """
    Generates overloaded functions.

    MultitypeFuncGraph is a class used to generate overloaded functions that accept different types as inputs.
    Initialize a `MultitypeFuncGraph` object with a name, and use `register` with input types as the decorator
    for the function to be registered. The object can then be called with different types of inputs,
    and works with `HyperMap` and `Map`.

    Args:
        name (str): Operator name.
        read_value (bool): If the registered function does not need to set values on Parameters,
            and all inputs are passed by value, set `read_value` to True. Default: False.

    Raises:
        ValueError: If a matching function cannot be found for the given arguments.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> # `add` is a metagraph object which will add two objects according to
        >>> # input type using the ".register" decorator.
        >>> from mindspore import Tensor
        >>> from mindspore import ops
        >>> from mindspore import dtype as mstype
        >>>
        >>> tensor_add = ops.Add()
        >>> add = MultitypeFuncGraph('add')
        >>> @add.register("Number", "Number")
        ... def add_scalar(x, y):
        ...     return x + y
        >>> @add.register("Tensor", "Tensor")
        ... def add_tensor(x, y):
        ...     return tensor_add(x, y)
        >>> output = add(1, 2)
        >>> print(output)
        3
        >>> output = add(Tensor([0.1, 0.6, 1.2], dtype=mstype.float32), Tensor([0.1, 0.6, 1.2], dtype=mstype.float32))
        >>> print(output)
        [0.2 1.2 2.4]
    """

    def __init__(self, name, read_value=False):
        """Initialize MultitypeFuncGraph."""
        MultitypeFuncGraph_.__init__(self, name)
        self.entries = list()
        if read_value:
            self.set_signatures((
                sig.make_sig('args', sig.sig_rw.RW_READ, sig.sig_kind.KIND_VAR_POSITIONAL),))

    def __call__(self, *args):
        if len(self.entries) == 1:
            output = self.entries[0][1](*args)
            return output
        types = tuple(map(mstype.get_py_obj_dtype, args))
        for sigs, fn in self.entries:
            if len(sigs) != len(types):
                continue
            if any(not mstype.issubclass_(type_, sig) for sig, type_ in zip(sigs, types)):
                continue
            output = fn(*args)
            return output
        raise ValueError(f"For 'MultitypeFuncGraph', cannot find a fn that matches the given args. "
                         f"Got (sigs, fn): {self.entries}, and (dtype, args): {types}.")

    def register(self, *type_names):
        """
        Register a function for the given type names.

        Args:
            type_names (Union[str, :class:`mindspore.dtype`]): Input type names or types.

        Returns:
            Function, a decorator that registers the function to run when called with the
            types described in `type_names`.
        """
        def deco(fn):
            def convert_type(type_input):
                if isinstance(type_input, str):
                    return mstype.typing.str_to_type(type_input)
                if not isinstance(type_input, mstype.Type):
                    raise TypeError(f"For 'MultitypeFuncGraph', register only supports str or {mstype.Type}, but got "
                                    f"'type_input': {type_input}.")
                return type_input

            types = tuple(map(convert_type, type_names))
            self.register_fn(type_names, fn)
            self.entries.append((types, fn))
            return fn
        return deco


class HyperMap(HyperMap_):
    """
    HyperMap will apply the set operation to the input sequences.

    Apply the operation to every element of the sequence or nested sequence. Different
    from `Map`, `HyperMap` supports applying the operation to nested structures.

    Args:
        ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
            the operation should be passed as the first input of the instance. Default: None.
        reverse (bool): In some scenarios the optimizer needs to be traversed in reverse order to improve
          parallel performance; general users can ignore this. `reverse` is the flag that decides whether
          to apply the operation in reverse order. Only supported in graph mode. Default: False.

    Inputs:
        - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be sequences with the same
          length, and each row of the sequences will be the inputs of the operation.

          If `ops` is `None`, the first input is the operation, and the others are inputs.

    Outputs:
        Sequence or nested sequence, the sequence of outputs after applying the operation,
        e.g. `operation(args[0][i], args[1][i])`.

    Raises:
        TypeError: If `ops` is neither MultitypeFuncGraph nor None.
        TypeError: If `args` is not a Tuple.

    Supported Platforms:
        ``Ascend`` ``GPU`` ``CPU``

    Examples:
        >>> from mindspore import dtype as mstype
        >>> nest_tensor_list = ((Tensor(1, mstype.float32), Tensor(2, mstype.float32)),
        ...                     (Tensor(3, mstype.float32), Tensor(4, mstype.float32)))
        >>> # square all the tensors in the nested list
        >>>
        >>> square = MultitypeFuncGraph('square')
        >>> @square.register("Tensor")
        ... def square_tensor(x):
        ...     return ops.square(x)
        >>>
        >>> common_map = HyperMap()
        >>> output = common_map(square, nest_tensor_list)
        >>> print(output)
        ((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
        (Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
        >>> square_map = HyperMap(square, False)
        >>> output = square_map(nest_tensor_list)
        >>> print(output)
        ((Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4)),
        (Tensor(shape=[], dtype=Float32, value= 9), Tensor(shape=[], dtype=Float32, value= 16)))
    """

    def __init__(self, ops=None, reverse=False):
        """Initialize HyperMap."""
        self.ops = ops
        if ops:
            HyperMap_.__init__(self, reverse, ops)
        else:
            HyperMap_.__init__(self, reverse)

    def __call__(self, *args):
        func = self.ops
        args_list = args
        hypermap = self
        if self.ops is None:
            func = args[0]
            args_list = args[1:]
            hypermap = partial(self, func)
        # Leaf node: apply the function directly.
        if not isinstance(args_list[0], (tuple, list)):
            return func(*args_list)
        return tuple(map(hypermap, *args_list))


class Map(Map_):
    """
    Map will apply the set operation on the input sequences.

    Apply the operation to every element of the sequence.

    Args:
        ops (Union[MultitypeFuncGraph, None]): `ops` is the operation to apply. If `ops` is `None`,
            the operation should be passed as the first input of the instance. Default: None.
        reverse (bool): In some scenarios the optimizer needs to be traversed in reverse order to improve
          parallel performance; general users can ignore this. `reverse` is the flag that decides whether
          to apply the operation in reverse order. Only supported in graph mode. Default: False.

    Inputs:
        - **args** (Tuple[sequence]) - If `ops` is not `None`, all the inputs should be sequences with the same
          length, and each row of the sequences will be the inputs of the operation. For example, if the length
          of args is 2, then for each `i` within the length of each sequence, `(args[0][i], args[1][i])` will be
          the input of the operation.

          If `ops` is `None`, the first input is the operation, and the others are inputs.

    Outputs:
        Sequence, the sequence of outputs after applying the operation, e.g. `operation(args[0][i], args[1][i])`.

    Examples:
        >>> from mindspore import dtype as mstype
        >>> tensor_list = (Tensor(1, mstype.float32), Tensor(2, mstype.float32), Tensor(3, mstype.float32))
        >>> # square all the tensors in the list
        >>>
        >>> square = MultitypeFuncGraph('square')
        >>> @square.register("Tensor")
        ... def square_tensor(x):
        ...     return ops.square(x)
        >>>
        >>> common_map = Map()
        >>> output = common_map(square, tensor_list)
        >>> print(output)
        (Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4),
        Tensor(shape=[], dtype=Float32, value= 9))
        >>> square_map = Map(square, False)
        >>> output = square_map(tensor_list)
        >>> print(output)
        (Tensor(shape=[], dtype=Float32, value= 1), Tensor(shape=[], dtype=Float32, value= 4),
        Tensor(shape=[], dtype=Float32, value= 9))
    """

    def __init__(self, ops=None, reverse=False):
        """Initialize Map."""
        self.ops = ops
        if ops:
            Map_.__init__(self, reverse, ops)
        else:
            Map_.__init__(self, reverse)

    def __call__(self, *args):
        func = self.ops
        args_list = args
        if self.ops is None:
            func = args[0]
            args_list = args[1:]
        return tuple(map(func, *args_list))


class _ListAppend(ListAppend_):
    """
    A metafuncgraph class that appends one element to a list.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Initialize _ListAppend."""
        ListAppend_.__init__(self, name)

    def __call__(self, *args):
        pass


_append = _ListAppend("append")


class _Tail(Tail_):
    """
    A metafuncgraph class that generates the tail elements of a tuple.

    Args:
        name (str): The name of the metafuncgraph object.
    """

    def __init__(self, name):
        """Initialize _Tail."""
        Tail_.__init__(self, name)

    def __call__(self, *args):
        pass


tail = _Tail('tail')


class _ZipOperation(ZipOperation_):
    """Generates a tuple of zip iterations for inputs."""

    def __init__(self, name):
        """Initialize _ZipOperation."""
        ZipOperation_.__init__(self, name)

    def __call__(self, *args):
        pass


zip_operation = _ZipOperation('zip_operation')
"""`zip_operation` will generate a tuple of zip iterations of inputs."""


env_get = MultitypeFuncGraph("env_get")


env_getitem = Primitive('env_getitem')
ref_to_embed = _grad_ops.RefToEmbed()
zeros_like = P.ZerosLike()


@env_get.register("EnvType", "Tensor")
def _tensor_env_get(env, parameter):
    """Used to get the item of `parameter` from `env`."""
    return env_getitem(env, ref_to_embed(parameter), zeros_like(parameter))
697