• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""control_ops"""
17
18from __future__ import absolute_import
19from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register
20from mindspore import _checkparam as validator
21from mindspore.common import dtype as mstype
22
23
class GeSwitch(PrimitiveWithInfer):
    """
    Adds control switch to data.

    Switch data flows into the ``False`` or ``True`` branch depending on the condition. If the condition is
    ``True``, the ``True`` branch is activated; otherwise the ``False`` branch is activated.

    Inputs:
        - **data** (Union[Tensor, Number]) - The data to be used for switch control.
        - **pred** (Tensor) - A bool tensor used as the condition for switch control. It is expected to be a
          scalar with shape `()`; note that shape inference also accepts a condition of rank 1.

    Outputs:
        tuple. Output is tuple(false_output, true_output). The elements in the tuple have the same shape as the
        input data. The false_output connects with the false branch and the true_output connects with the true
        branch.

    Raises:
        TypeError: If `data` is neither a Tensor nor a Number.
        TypeError: If `pred` is not a Tensor, or its data type is not bool.

    Examples:
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.square = ops.Square()
        ...         self.add = ops.Add()
        ...         self.value = Tensor(np.full((1), 3), mindspore.float32)
        ...         self.switch = ops.GeSwitch()
        ...         self.merge = ops.Merge()
        ...         self.less = ops.Less()
        ...
        ...     def construct(self, x, y):
        ...         cond = self.less(x, y)
        ...         sf1, st1 = self.switch(x, cond)
        ...         sf2, st2 = self.switch(y, cond)
        ...         add_ret = self.add(sf1, sf2)
        ...         sf3, st3 = self.switch(self.value, cond)
        ...         sq_ret = self.square(st3)
        ...         ret = self.merge((add_ret, sq_ret))
        ...         return ret[0]
        ...
        >>> x = Tensor(10.0, dtype=mindspore.float32)
        >>> y = Tensor(5.0, dtype=mindspore.float32)
        >>> net = Net()
        >>> output = net(x, y)
        >>> print(output)
    """

    @prim_attr_register
    def __init__(self):
        """Initialize GeSwitch."""

    def __call__(self, data, pred):
        # No Python-side implementation is provided: invoking the primitive directly always raises.
        raise NotImplementedError

    def infer_shape(self, data, pred):
        """
        Infer the output shapes.

        Args:
            data: Shape of the `data` input.
            pred: Shape of the condition input; its length (the rank) must lie in [0, 1].

        Returns:
            tuple: (false_output_shape, true_output_shape), both equal to `data`.
        """
        # NOTE(review): the bounds are inclusive on both ends (INC_BOTH), so a rank-1 condition
        # passes even though the class docstring asks for a scalar.
        validator.check_int_range(len(pred), 0, 1, validator.INC_BOTH, "pred rank", self.name)
        return data, data

    def infer_dtype(self, data_type, pred_type):
        """
        Infer the output data types.

        Args:
            data_type: Type of `data`; must be a tensor type or one of the number types.
            pred_type: Type of `pred`; must be a tensor of bool.

        Returns:
            tuple: (false_output_type, true_output_type), both equal to `data_type`.
        """
        validator.check_subclass(
            "data", data_type, (mstype.tensor_type,) + mstype.number_type, self.name)
        validator.check_tensor_dtype_valid("pred", pred_type, [mstype.bool_], self.name)
        return data_type, data_type
87
88
class Merge(PrimitiveWithInfer):
    """
    Merges all input data to one.

    One and only one of the inputs must be selected as the output.

    Inputs:
        - **inputs** (Union[tuple, list]) - The data to be merged. All elements must have the same data type,
          which must be bool or a number type.

    Outputs:
        tuple. Output is tuple(`data`, `output_index`). The `data` has the same shape and data type as the first
        element of `inputs`. The `output_index` is an int32 value of shape (1,).

    Raises:
        TypeError: If `inputs` is neither tuple nor list.
        TypeError: If the elements of `inputs` do not share the same data type, or that type is neither bool
            nor a number type.
        ValueError: If `inputs` is empty.

    Examples:
        >>> merge = ops.Merge()
        >>> input_x = Tensor(np.linspace(0, 8, 8).reshape(2, 4), mindspore.float32)
        >>> input_y = Tensor(np.random.randint(-4, 4, (2, 4)), mindspore.float32)
        >>> result = merge((input_x, input_y))
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Merge."""

    def __call__(self, *args):
        # No Python-side implementation is provided: invoking the primitive directly always raises.
        raise NotImplementedError

    def _check_inputs_not_empty(self, inputs):
        """Raise ValueError when there is nothing to merge (both infer methods index inputs[0])."""
        if not inputs:
            raise ValueError("For '%s', 'inputs' must contain at least one element, but got an empty sequence."
                             % self.name)

    def infer_shape(self, inputs):
        """
        Infer the output shapes.

        Args:
            inputs: Sequence of shapes, one per element to merge.

        Returns:
            tuple: (shape of the first element, [1]); the second entry is the shape of `output_index`.
        """
        self._check_inputs_not_empty(inputs)
        return inputs[0], [1]

    def infer_dtype(self, inputs):
        """
        Infer the output data types.

        Args:
            inputs: Sequence of types, one per element to merge; all must be the same bool or number type.

        Returns:
            tuple: (type of the first element, int32); the second entry is the type of `output_index`.
        """
        self._check_inputs_not_empty(inputs)
        # Keys like 'inputs[0]' name each element in the validator's error messages.
        args = {'inputs[%d]' % i: item for i, item in enumerate(inputs)}
        validator.check_scalar_or_tensor_types_same(args, (mstype.bool_,) + mstype.number_type, self.name)
        return inputs[0], mstype.int32
128