# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================

"""control_ops"""
from ..primitive import PrimitiveWithInfer, prim_attr_register
from ..._checkparam import Validator as validator
from ...common import dtype as mstype


class GeSwitch(PrimitiveWithInfer):
    """
    Adds control switch to data.

    Switch data flows into false or true branch depending on the condition. If the condition is true,
    the true branch will be activated, or vice versa.

    Inputs:
        - **data** (Union[Tensor, Number]) - The data to be used for switch control.
        - **pred** (Tensor) - It must be a scalar whose type is bool and shape is `()`. It is used as the condition
          for switch control.

    Outputs:
        tuple. Output is tuple(false_output, true_output). The elements in the tuple have the same shape as the
        input data. The false_output connects with the false_branch and the true_output connects with the
        true_branch.

    Raises:
        TypeError: If `data` is neither a Tensor nor a Number.
        TypeError: If `pred` is not a Tensor.

    Examples:
        >>> class Net(nn.Cell):
        ...     def __init__(self):
        ...         super(Net, self).__init__()
        ...         self.square = ops.Square()
        ...         self.add = ops.Add()
        ...         self.value = Tensor(np.full((1), 3), mindspore.float32)
        ...         self.switch = ops.GeSwitch()
        ...         self.merge = ops.Merge()
        ...         self.less = ops.Less()
        ...
        ...     def construct(self, x, y):
        ...         cond = self.less(x, y)
        ...         st1, sf1 = self.switch(x, cond)
        ...         st2, sf2 = self.switch(y, cond)
        ...         add_ret = self.add(st1, st2)
        ...         st3, sf3 = self.switch(self.value, cond)
        ...         sq_ret = self.square(sf3)
        ...         ret = self.merge((add_ret, sq_ret))
        ...         return ret[0]
        ...
        >>> x = Tensor(10.0, dtype=mindspore.float32)
        >>> y = Tensor(5.0, dtype=mindspore.float32)
        >>> net = Net()
        >>> output = net(x, y)
        >>> print(output)
    """

    @prim_attr_register
    def __init__(self):
        """Initialize GeSwitch."""

    def __call__(self, data, pred):
        """Direct invocation of this primitive is not implemented; it always raises NotImplementedError."""
        raise NotImplementedError

    def infer_shape(self, data, pred):
        """
        Infer output shapes.

        `pred` must have rank 0 (an empty shape). Both outputs take the shape of `data`.

        Returns:
            tuple: (data shape, data shape) for (false_output, true_output).
        """
        # A rank-0 shape has length 0, i.e. `pred` must be a scalar.
        validator.check_equal_int(len(pred), 0, "pred rank", self.name)
        return data, data

    def infer_dtype(self, data_type, pred_type):
        """
        Infer output dtypes.

        `data_type` must be a tensor type or one of the number types. `pred_type` must be a bool tensor.
        Both outputs take the dtype of `data`.

        Returns:
            tuple: (data_type, data_type) for (false_output, true_output).
        """
        validator.check_subclass(
            "data", data_type, (mstype.tensor,) + mstype.number_type, self.name)
        validator.check_tensor_dtype_valid("pred", pred_type, [mstype.bool_], self.name)
        return data_type, data_type


class Merge(PrimitiveWithInfer):
    """
    Merges all input data to one.

    One and only one of the inputs must be selected as the output.

    Inputs:
        - **inputs** (Union[Tuple, List]) - The data to be merged. All tuple elements must have the same data type.

    Outputs:
        tuple. Output is tuple(`data`, `output_index`). The `data` has the same shape as an `inputs` element.
        The `output_index` has dtype int32 and shape (1,).

    Raises:
        TypeError: If `inputs` is neither Tuple nor List.

    Examples:
        >>> merge = ops.Merge()
        >>> input_x = Tensor(np.linspace(0, 8, 8).reshape(2, 4), mindspore.float32)
        >>> input_y = Tensor(np.random.randint(-4, 4, (2, 4)), mindspore.float32)
        >>> result = merge((input_x, input_y))
    """

    @prim_attr_register
    def __init__(self):
        """Initialize Merge."""

    def __call__(self, *args):
        """Direct invocation of this primitive is not implemented; it always raises NotImplementedError."""
        raise NotImplementedError

    def infer_shape(self, inputs):
        """
        Infer output shapes.

        Returns:
            tuple: (shape of inputs[0], [1]). The second entry is the shape of `output_index`.
        """
        # NOTE(review): indexing inputs[0] assumes at least one input; an empty
        # sequence raises IndexError here rather than a validator error.
        return inputs[0], [1]

    def infer_dtype(self, inputs):
        """
        Infer output dtypes.

        All elements of `inputs` must share one dtype drawn from bool or the number types.

        Returns:
            tuple: (dtype of inputs[0], mstype.int32). The second entry is the dtype of `output_index`.
        """
        # Each element is labelled by its position; the labels are handed to the
        # validator (presumably for use in its error messages).
        args = {'inputs[%d]' % i: item for i, item in enumerate(inputs)}
        validator.check_scalar_or_tensor_types_same(args, (mstype.bool_,) + mstype.number_type, self.name)
        return inputs[0], mstype.int32