# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""builtin_operations

Plain-Python implementations of builtin operations: scalar arithmetic and
comparisons, tuple/list access, boolean logic, casts and related helpers.
"""
import numpy as np
from mindspore.ops import functional as F
from mindspore.ops import composite as C
from mindspore.common.tensor import Tensor
import mindspore.common.dtype as mstype
from mindspore.common.dtype import dtype_to_nptype, get_py_obj_dtype


def ScalarAdd(x, y):
    """Implement `scalar_add`.

    Returns:
        ``x + y``.
    """
    return x + y


def ScalarMul(x, y):
    """Implement `scalar_mul`.

    Returns:
        ``x * y``.
    """
    return x * y


def ScalarMod(x, y):
    """Implement `scalar_mod`.

    Returns:
        ``x % y``. For Python ints/floats the result takes the sign of ``y``
        (Python's floor-modulo semantics).
    """
    return x % y


def ScalarSub(x, y):
    """Implement `scalar_sub`.

    Returns:
        ``x - y``.
    """
    return x - y


def ScalarUsub(x):
    """Implement `scalar_usub` (unary minus).

    Returns:
        ``-x``.
    """
    return -x


def TupleGetItem(x, index):
    """Implement `tuple_getitem`.

    Args:
        x: A sequence, or a Tensor.
        index: The index (or any key accepted by ``x[...]``) to read.

    Returns:
        ``x[index]``. When ``x`` is a Tensor, it is converted to a NumPy array
        with ``asnumpy()``, indexed, and the selected value is wrapped back
        into a Tensor.
    """
    if isinstance(x, Tensor):
        x = x.asnumpy()
        y = x[index]
        return Tensor(y)
    return x[index]


def scalar_gt(x, y):
    """Implement `scalar_gt`.

    Returns:
        ``x > y``.
    """
    return x > y


def scalar_ne(x, y):
    """Implement `scalar_ne`.

    Returns:
        ``x != y``.
    """
    return x != y


def scalar_eq(x, y):
    """Implement `scalar_eq`.

    Returns:
        ``x == y``.
    """
    return x == y


def scalar_le(x, y):
    """Implement `scalar_le`.

    Returns:
        ``x <= y``.
    """
    return x <= y


def scalar_lt(x, y):
    """Implement `scalar_lt`.

    Returns:
        ``x < y``.
    """
    return x < y


def identity(x):
    """Implement `identity`.

    Returns:
        ``x`` itself, unchanged.
    """
    return x
def zeros_like_tensor(x):
    """Implement `zeros_like_tensor`.

    Args:
        x: A Tensor; only its shape is used.

    Returns:
        Tensor: A float32 Tensor of zeros with the same shape as ``x``.
    """
    # NOTE(review): the result is always float32, regardless of x's dtype.
    shape = x.asnumpy().shape
    # Allocate directly as float32 rather than building a float64 array and
    # converting it afterwards; the resulting values and dtype are identical.
    return Tensor(np.zeros(shape, dtype=np.float32))


def Switch(c, x, y):
    """Implement `switch`.

    Returns:
        ``x`` when ``c`` is truthy, otherwise ``y``.
    """
    return x if c else y


def list_getitem(data, item):
    """Implement `list_getitem`.

    Returns:
        ``data[item]``.
    """
    return data[item]


def bool_not(x):
    """Implement `bool_not`.

    Returns:
        ``not x``.
    """
    return not x


def bool_and(x, y):
    """Implement `bool_and`.

    Returns:
        ``x and y`` (Python short-circuit semantics, so an operand is
        returned rather than a coerced bool).
    """
    return x and y


def bool_or(x, y):
    """Implement `bool_or`.

    Returns:
        ``x or y`` (Python short-circuit semantics, so an operand is
        returned rather than a coerced bool).
    """
    return x or y


def make_list(*xs):
    """Implement `make_list`.

    Returns:
        list: The positional arguments collected into a new list.
    """
    return list(xs)


def list_len(x):
    """Implement `list_len`.

    Returns:
        int: ``len(x)``.
    """
    return len(x)


def Depend(value, expr):
    """Implement `Depend`.

    Returns:
        ``value`` unchanged; ``expr`` is ignored by this implementation.
    """
    return value


def UpdateState(monad, *exprs):
    """Implement `UpdateState`.

    Returns:
        ``monad`` unchanged; ``exprs`` are ignored by this implementation.
    """
    return monad


def Load(value, u=None):
    """Implement `Load`.

    Returns:
        ``value`` unchanged; ``u`` is ignored by this implementation.
    """
    return value


# only used in PyNative mode
def make_ref(key, value, ref):
    """Implement `make_ref`.

    Returns:
        ``value`` unchanged; ``key`` and ``ref`` are ignored by this
        implementation.
    """
    return value


def scalar_cast(x, t):
    """Implement scalar_cast.

    Args:
        x: The scalar value to cast.
        t: A MindSpore dtype, mapped to a NumPy type via ``dtype_to_nptype``.

    Returns:
        The cast value converted to a native Python scalar.
    """
    np_type = dtype_to_nptype(t)
    value = np_type(x)
    # For NumPy scalar types (e.g. np.float32), np_type(x) is an np.generic,
    # not an ndarray, so the unbound ``np.ndarray.item`` raises TypeError on
    # it. The instance's own ``item()`` works for NumPy scalars and arrays.
    return value.item()


def typeof(x):
    """Implement typeof.

    Returns:
        The dtype that ``get_py_obj_dtype`` reports for ``x``.
    """
    return get_py_obj_dtype(x)


def tuple_to_array(x):
    """Implement `tuple_to_array`.

    Returns:
        Tensor: A Tensor built from ``np.array(x)``.
    """
    return Tensor(np.array(x))


def stop_gradient(x):
    """Implement `stop_gradient`.

    Returns:
        ``x`` unchanged.
    """
    return x


# Shared HyperMap instance used by mixed_precision_cast below.
hyper_map = C.HyperMap()


def mixed_precision_cast(dst_type, x):
    """Implement `mixed_precision_cast`.

    Args:
        dst_type: Target dtype passed to ``F.cast``.
        x: The value(s) to process. ``hyper_map`` applies the cast to ``x``,
            presumably element-wise over tuples/lists (TODO confirm).

    Returns:
        ``x`` with every float32/float16/float64 Tensor cast to ``dst_type``.
        Non-Tensor values and Tensors of other dtypes are returned unchanged.
    """

    def cast_inner(data):
        # Only floating-point Tensors take part in mixed-precision casting.
        if isinstance(data, Tensor) and data.dtype in (mstype.float32, mstype.float16, mstype.float64):
            return F.cast(data, dst_type)
        return data

    return hyper_map(cast_inner, x)