• Home
  • Line#
  • Scopes#
  • Navigate#
  • Raw
  • Download
1# Copyright 2020 Huawei Technologies Co., Ltd
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# ============================================================================
15
16"""control_ops"""
17from ..primitive import PrimitiveWithInfer, prim_attr_register
18from ..._checkparam import Validator as validator
19from ...common import dtype as mstype
20
21
class GeSwitch(PrimitiveWithInfer):
    """
    Adds control switch to data.

    Routes the input data into either the false branch or the true branch according to the condition.
    When the condition is true the true branch is activated, and vice versa.

    Inputs:
        - **data** (Union[Tensor, Number]) - The data to be used for switch control.
        - **pred** (Tensor) - The condition for switch control. It must be a scalar whose type is bool and
          whose shape is `()`.

    Outputs:
        tuple. Output is tuple(false_output, true_output). The elements in the tuple have the same shape as the
        input data. The false_output connects with the false_branch and the true_output connects with the
        true_branch.

    Raises:
        TypeError: If `data` is neither a Tensor nor a Number.
        TypeError: If `pred` is not a Tensor.

    Examples:
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.square = ops.Square()
        ...         self.add = ops.Add()
        ...         self.value = Tensor(np.full((1), 3), mindspore.float32)
        ...         self.switch = ops.GeSwitch()
        ...         self.merge = ops.Merge()
        ...         self.less = ops.Less()
        ...
        ...     def construct(self, x, y):
        ...         cond = self.less(x, y)
        ...         st1, sf1 = self.switch(x, cond)
        ...         st2, sf2 = self.switch(y, cond)
        ...         add_ret = self.add(st1, st2)
        ...         st3, sf3 = self.switch(self.value, cond)
        ...         sq_ret = self.square(sf3)
        ...         ret = self.merge((add_ret, sq_ret))
        ...         return ret[0]
        ...
        >>> x = Tensor(10.0, dtype=mindspore.float32)
        >>> y = Tensor(5.0, dtype=mindspore.float32)
        >>> net = Net()
        >>> output = net(x, y)
        >>> print(output)
    """

    @prim_attr_register
    def __init__(self):
        """Initialize the GeSwitch primitive."""

    def __call__(self, data, pred):
        # Invoking the primitive object directly is not supported: it always raises.
        raise NotImplementedError

    def infer_shape(self, data, pred):
        """Check that `pred` has rank 0 and give both outputs the shape of `data`."""
        pred_rank = len(pred)
        validator.check_equal_int(pred_rank, 0, "pred rank", self.name)
        return (data, data)

    def infer_dtype(self, data_type, pred_type):
        """Check the dtypes of `data` and `pred` and give both outputs the dtype of `data`."""
        # `data` may be a tensor or any numeric scalar type.
        accepted_data_types = (mstype.tensor,) + mstype.number_type
        validator.check_subclass("data", data_type, accepted_data_types, self.name)
        # `pred` is only accepted as a bool tensor.
        validator.check_tensor_dtype_valid("pred", pred_type, [mstype.bool_], self.name)
        return (data_type, data_type)
85
86
class Merge(PrimitiveWithInfer):
    """
    Merges all input data to one.

    One and only one of the inputs must be selected as the output.

    Inputs:
        - **inputs** (Union[Tuple, List]) - The data to be merged. All tuple elements must have the same data type.

    Outputs:
        tuple. Output is tuple(`data`, `output_index`). The `data` has the same shape as an element of `inputs`.

    Raises:
        TypeError: If `inputs` is neither Tuple nor list.

    Examples:
        >>> merge = ops.Merge()
        >>> input_x = Tensor(np.linspace(0, 8, 8).reshape(2, 4), mindspore.float32)
        >>> input_y = Tensor(np.random.randint(-4, 4, (2, 4)), mindspore.float32)
        >>> result = merge((input_x, input_y))
    """

    @prim_attr_register
    def __init__(self):
        """Initialize the Merge primitive."""

    def __call__(self, *args):
        # Invoking the primitive object directly is not supported: it always raises.
        raise NotImplementedError

    def infer_shape(self, inputs):
        """Return the shape of the first input for `data` and `[1]` for `output_index`."""
        data_shape = inputs[0]
        index_shape = [1]
        return data_shape, index_shape

    def infer_dtype(self, inputs):
        """Check that all inputs share one bool/number dtype; return that dtype and int32."""
        # Label every element by position so the validator can refer to it by name.
        named_dtypes = {'inputs[%d]' % pos: dtype for pos, dtype in enumerate(inputs)}
        accepted_types = (mstype.bool_,) + mstype.number_type

        validator.check_scalar_or_tensor_types_same(named_dtypes, accepted_types, self.name)
        return inputs[0], mstype.int32
126