# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""control_ops"""

from __future__ import absolute_import
from mindspore.ops.primitive import PrimitiveWithInfer, prim_attr_register
from mindspore import _checkparam as validator
from mindspore.common import dtype as mstype


class GeSwitch(PrimitiveWithInfer):
    """
    Adds control switch to data.

    Switch data flows into ``False`` or ``True`` branch depending on the condition. If the condition is ``True`` ,
    the ``True`` branch will be activated, or vice versa.

    Inputs:
        - **data** (Union[Tensor, Number]) - The data to be used for switch control.
        - **pred** (Tensor) - It must be a scalar whose type is bool and shape is `()`. It is used as condition for
          switch control.

    Outputs:
        tuple. Output is tuple(false_output, true_output). The elements in the tuple have the same shape as input
        data. The false_output connects with the false_branch and the true_output connects with the true_branch.

    Raises:
        TypeError: If `data` is neither a Tensor nor a Number.
        TypeError: If `pred` is not a Tensor.

    Examples:
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.square = ops.Square()
        ...         self.add = ops.Add()
        ...         self.value = Tensor(np.full((1), 3), mindspore.float32)
        ...         self.switch = ops.GeSwitch()
        ...         self.merge = ops.Merge()
        ...         self.less = ops.Less()
        ...
        ...     def construct(self, x, y):
        ...         cond = self.less(x, y)
        ...         st1, sf1 = self.switch(x, cond)
        ...         st2, sf2 = self.switch(y, cond)
        ...         add_ret = self.add(st1, st2)
        ...         st3, sf3 = self.switch(self.value, cond)
        ...         sq_ret = self.square(sf3)
        ...         ret = self.merge((add_ret, sq_ret))
        ...         return ret[0]
        ...
        >>> x = Tensor(10.0, dtype=mindspore.float32)
        >>> y = Tensor(5.0, dtype=mindspore.float32)
        >>> net = Net()
        >>> output = net(x, y)
        >>> print(output)
    """

    @prim_attr_register
    def __init__(self):
        """Initialize GeSwitch."""

    def __call__(self, data, pred):
        # Control-flow primitive: it only has meaning inside a compiled graph,
        # so a direct (eager) call is deliberately rejected.
        raise NotImplementedError

    def infer_shape(self, data, pred):
        # `pred` must be (close to) a scalar: its rank may only be 0 or 1.
        validator.check_int_range(len(pred), 0, 1, validator.INC_BOTH, "pred rank", self.name)
        # Both branches carry data of the original shape.
        return data, data

    def infer_dtype(self, data_type, pred_type):
        # `data` may be a tensor or any number type; `pred` must be a bool tensor.
        validator.check_subclass(
            "data", data_type, (mstype.tensor_type,) + mstype.number_type, self.name)
        validator.check_tensor_dtype_valid("pred", pred_type, [mstype.bool_], self.name)
        # Both branch outputs keep the input dtype.
        return data_type, data_type


class Merge(PrimitiveWithInfer):
    """
    Merges all input data to one.

    One and only one of the inputs must be selected as the output.

    Inputs:
        - **inputs** (Union(Tuple, List)) - The data to be merged. All tuple elements must have the same data type.

    Outputs:
        tuple. Output is tuple(`data`, `output_index`). The `data` has the same shape of `inputs` element.

    Raises:
        TypeError: If `inputs` is neither Tuple nor list.

    Examples:
        >>> merge = ops.Merge()
        >>> input_x = Tensor(np.linspace(0, 8, 8).reshape(2, 4), mindspore.float32)
        >>> input_y = Tensor(np.random.randint(-4, 4, (2, 4)), mindspore.float32)
        >>> result = merge((input_x, input_y))
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Merge."""

    def __call__(self, *args):
        # Control-flow primitive: it only has meaning inside a compiled graph,
        # so a direct (eager) call is deliberately rejected.
        raise NotImplementedError

    def infer_shape(self, inputs):
        # Output data takes the shape of the first candidate; the selected-index
        # output is a length-1 vector.
        return inputs[0], [1]

    def infer_dtype(self, inputs):
        # All candidate inputs must share one scalar-or-tensor type from
        # {bool} ∪ number types.
        args = {}
        for i, item in enumerate(inputs):
            args['inputs[%d]' % i] = item

        validator.check_scalar_or_tensor_types_same(args, (mstype.bool_,) + mstype.number_type, self.name)
        # Output data keeps the input dtype; the index output is int32.
        return inputs[0], mstype.int32